From 24cf2fcd90a8409da2e5e38118c73eb4af13121f Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 29 Dec 2017 15:16:49 +0800 Subject: [PATCH] move cos_sim_functor to math --- paddle/operators/CMakeLists.txt | 4 +- paddle/operators/cos_sim_op.cc | 22 --- paddle/operators/cos_sim_op.cu | 45 ------ paddle/operators/cos_sim_op.h | 153 +-------------------- paddle/operators/math/CMakeLists.txt | 2 + paddle/operators/math/cos_sim_functor.cc | 48 +++++++ paddle/operators/math/cos_sim_functor.cu | 64 +++++++++ paddle/operators/math/cos_sim_functor.h | 166 +++++++++++++++++++++++ 8 files changed, 290 insertions(+), 214 deletions(-) create mode 100644 paddle/operators/math/cos_sim_functor.cc create mode 100644 paddle/operators/math/cos_sim_functor.cu create mode 100644 paddle/operators/math/cos_sim_functor.h diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 5aaaf993323..c6da04b5b40 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -210,7 +210,8 @@ set(DEPS_OPS save_op load_op send_op - recv_op) + recv_op + cos_sim_op) if(WITH_DISTRIBUTE) add_subdirectory(detail) @@ -256,6 +257,7 @@ op_library(lstm_op DEPS sequence2batch lstm_compute) op_library(conv_transpose_op DEPS vol2col) op_library(gru_op DEPS sequence2batch gru_compute) op_library(recurrent_op SRCS recurrent_op.cc DEPS executor) +op_library(cos_sim_op DEPS cos_sim_functor) # FIXME(typhoonzero): save/load depends lodtensor serialization functions op_library(save_op DEPS lod_tensor) diff --git a/paddle/operators/cos_sim_op.cc b/paddle/operators/cos_sim_op.cc index d4f3ca5e32c..9019a1edb37 100644 --- a/paddle/operators/cos_sim_op.cc +++ b/paddle/operators/cos_sim_op.cc @@ -149,28 +149,6 @@ class CosSimOpGrad : public framework::OperatorWithKernel { } }; -template -struct CosSimDyFunctor { - inline void operator()(const platform::CPUDeviceContext& ctx, const T* x_norm, - const T* y_norm, const T* x, const T* y, const T* z, - const T* dz, const size_t rows, const size_t cols, - T* dy) const { - for (size_t row_id = 0; row_id < rows; ++row_id) { - auto xy_norm_prod = x_norm[row_id] * y_norm[0]; - auto dz_data = dz[row_id]; - auto z_data = z[row_id]; - auto* x_data = x + cols * row_id; - auto reciprocal_xy_norm_prod = 1 / xy_norm_prod; - - auto y_norm_square = y_norm[0] * y_norm[0]; - auto reciprocal_y_norm_square = 1 / y_norm_square; - for (size_t i = 0; i < cols; ++i) { - dy[i] += dz_data * (x_data[i] * reciprocal_xy_norm_prod - - z_data * y[i] * reciprocal_y_norm_square); - } - } - } -}; } // namespace operators } // namespace paddle diff --git a/paddle/operators/cos_sim_op.cu b/paddle/operators/cos_sim_op.cu index 891436c9483..9e5d1b6e4f0 100644 --- a/paddle/operators/cos_sim_op.cu +++ b/paddle/operators/cos_sim_op.cu @@ -14,51 +14,6 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/operators/cos_sim_op.h" -#include "paddle/platform/cuda_helper.h" - -namespace paddle { -namespace operators { - -template -__global__ void CosSimDyKernel(const T* x_norm, const T* y_norm, const T* x, - const T* y, const T* z, const T* dz, - const size_t rows, const size_t cols, T* dy) { - int grid_size = blockDim.x * gridDim.x; - T y_norm_data = y_norm[0]; - for (int row_id = blockIdx.x * blockDim.x + threadIdx.x; row_id < rows; - row_id += grid_size) { - T xy_norm_prod = x_norm[row_id] * y_norm_data; - T dz_data = dz[row_id]; - T z_data = z[row_id]; - const T* x_data = x + cols * row_id; - T reciprocal_xy_norm_prod = 1 / xy_norm_prod; - - T y_norm_square = y_norm_data * y_norm_data; - T reciprocal_y_norm_square = 1 / y_norm_square; - for (size_t i = 0; i < cols; ++i) { - T dy_data = dz_data * (x_data[i] * reciprocal_xy_norm_prod - - z_data * y[i] * reciprocal_y_norm_square); - platform::CudaAtomicAdd(dy + i, dy_data); - } - } -} - -template -struct CosSimDyFunctor { - inline void operator()(const platform::CUDADeviceContext& ctx, - const T* x_norm, const T* y_norm, const T* x, - const T* y, const T* z, const T* dz, const size_t rows, - const size_t cols, T* dy) const { - const int block_size = 512; - dim3 threads(block_size, 1); - dim3 grid(1, (rows + block_size - 1) / block_size); - CosSimDyKernel<<>>( - x_norm, y_norm, x, y, z, dz, rows, cols, dy); - } -}; - -} // namespace operators -} // namespace paddle namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/operators/cos_sim_op.h b/paddle/operators/cos_sim_op.h index 160edb0b561..eadcca55f9b 100644 --- a/paddle/operators/cos_sim_op.h +++ b/paddle/operators/cos_sim_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/cos_sim_functor.h" #include "paddle/operators/math/math_function.h" #include "paddle/platform/for_range.h" @@ -22,59 +23,6 @@ namespace operators { using Tensor = framework::Tensor; -template -struct CosSimFunctor { - CosSimFunctor(const T* x, const T* y, T* x_norm, T* y_norm, T* z, int cols) - : x_norm_(x_norm), - y_norm_(y_norm), - x_(x), - y_(y), - z_(z), - cols_(static_cast(cols)) {} - - inline HOSTDEVICE void operator()(size_t row_id) const { - auto* x = x_ + cols_ * row_id; - T xx = 0, xy = 0, yy = 0; - if (same_row) { - auto* y = y_ + cols_ * row_id; - T tep_x, tep_y; - for (size_t i = 0; i < cols_; ++i) { - tep_x = x[i]; - tep_y = y[i]; - xx += tep_x * tep_x; - yy += tep_y * tep_y; - xy += tep_x * tep_y; - } - xx = sqrt(xx); - yy = sqrt(yy); - y_norm_[row_id] = yy; - x_norm_[row_id] = xx; - z_[row_id] = xy / (xx * yy); - } else { // This can be wrote in a better way. - T tep_x, tep_y; - for (size_t i = 0; i < cols_; ++i) { - tep_x = x[i]; - tep_y = y_[i]; - xx += tep_x * tep_x; - yy += tep_y * tep_y; - xy += tep_x * tep_y; - } - xx = sqrt(xx); - yy = sqrt(yy); - if (row_id == 0) y_norm_[0] = yy; - x_norm_[row_id] = xx; - z_[row_id] = xy / (xx * yy); - } - } - - T* x_norm_; - T* y_norm_; - const T* x_; - const T* y_; - T* z_; - const size_t cols_; -}; - template class CosSimKernel : public framework::OpKernel { public: @@ -95,14 +43,14 @@ class CosSimKernel : public framework::OpKernel { int cols = framework::product(in_x->dims()) / rows_x; if (rows_x == rows_y) { - CosSimFunctor functor( + math::CosSimFunctor functor( in_x->data(), in_y->data(), out_x_norm->data(), out_y_norm->data(), out_z->data(), cols); platform::ForRange for_range( static_cast(context.device_context()), rows_x); for_range(functor); } else { - CosSimFunctor functor( + math::CosSimFunctor functor( in_x->data(), in_y->data(), out_x_norm->data(), out_y_norm->data(), out_z->data(), cols); platform::ForRange for_range( @@ -112,93 +60,6 @@ class CosSimKernel : public framework::OpKernel { } }; -template -struct CosSimGradFunctor { - CosSimGradFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y, - const T* z, const T* dz, T* dx, int cols) - : x_norm_(x_norm), - y_norm_(y_norm), - x_(x), - y_(y), - z_(z), - dz_(dz), - dx_(dx), - cols_(static_cast(cols)) {} - - inline HOSTDEVICE void operator()(size_t row_id) const { - auto x_norm_square = x_norm_[row_id] * x_norm_[row_id]; - auto xy_norm_prod = x_norm_[row_id] * y_norm_[row_id]; - auto dz = dz_[row_id]; - auto z = z_[row_id]; - - auto* dx = dx_ + cols_ * row_id; - auto* x = x_ + cols_ * row_id; - auto* y = y_ + cols_ * row_id; - - auto reciprocal_xy_norm_prod = 1 / xy_norm_prod; - auto reciprocal_x_norm_square = 1 / x_norm_square; - for (size_t i = 0; i < cols_; ++i) { - dx[i] = dz * (y[i] * reciprocal_xy_norm_prod - - z * x[i] * reciprocal_x_norm_square); - } - } - - const T* x_norm_; - const T* y_norm_; - const T* x_; - const T* y_; - const T* z_; - const T* dz_; - T* dx_; - const size_t cols_; -}; - -template -struct CosSimDxFunctor { - CosSimDxFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y, - const T* z, const T* dz, T* dx, int cols) - : x_norm_(x_norm), - y_norm_(y_norm), - x_(x), - y_(y), - z_(z), - dz_(dz), - dx_(dx), - cols_(static_cast(cols)) {} - - inline HOSTDEVICE void operator()(size_t row_id) const { - auto xy_norm_prod = x_norm_[row_id] * y_norm_[0]; - auto dz = dz_[row_id]; - auto z = z_[row_id]; - auto* x = x_ + cols_ * row_id; - auto reciprocal_xy_norm_prod = 1 / xy_norm_prod; - auto x_norm_square = x_norm_[row_id] * x_norm_[row_id]; - auto* dx = dx_ + cols_ * row_id; - auto reciprocal_x_norm_square = 1 / x_norm_square; - - for (size_t i = 0; i < cols_; ++i) { - dx[i] = dz * (y_[i] * reciprocal_xy_norm_prod - - z * x[i] * reciprocal_x_norm_square); - } - } - const T* x_norm_; - const T* y_norm_; - const T* x_; - const T* y_; - const T* z_; - const T* dz_; - T* dx_; - const size_t cols_; -}; - -template -struct CosSimDyFunctor { - inline void operator()(const DeviceContext& ctx, const T* x_norm, - const T* y_norm, const T* x, const T* y, const T* z, - const T* dz, const size_t rows, const size_t cols, - T* dy) const; -}; - template class CosSimGradKernel : public framework::OpKernel { public: @@ -220,7 +81,7 @@ class CosSimGradKernel : public framework::OpKernel { if (rows_x == rows_y) { if (out_grad_x) { - CosSimGradFunctor functor( + math::CosSimGradFunctor functor( in_x_norm->data(), in_y_norm->data(), in_x->data(), in_y->data(), in_z->data(), in_grad_z->data(), out_grad_x->mutable_data(context.GetPlace()), cols); @@ -230,7 +91,7 @@ class CosSimGradKernel : public framework::OpKernel { for_range(functor); } if (out_grad_y) { - CosSimGradFunctor functor( + math::CosSimGradFunctor functor( in_y_norm->data(), in_x_norm->data(), in_y->data(), in_x->data(), in_z->data(), in_grad_z->data(), out_grad_y->mutable_data(context.GetPlace()), cols); @@ -241,7 +102,7 @@ class CosSimGradKernel : public framework::OpKernel { } } else { if (out_grad_x) { - CosSimDxFunctor functor( + math::CosSimDxFunctor functor( in_x_norm->data(), in_y_norm->data(), in_x->data(), in_y->data(), in_z->data(), in_grad_z->data(), out_grad_x->mutable_data(context.GetPlace()), cols); @@ -256,7 +117,7 @@ class CosSimGradKernel : public framework::OpKernel { auto& dev_ctx = context.template device_context(); set_zero(dev_ctx, out_grad_y, static_cast(0)); - CosSimDyFunctor functor; + math::CosSimDyFunctor functor; functor(dev_ctx, in_x_norm->data(), in_y_norm->data(), in_x->data(), in_y->data(), in_z->data(), in_grad_z->data(), static_cast(rows_x), diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index bf47879f772..830ae53cbe9 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -16,6 +16,7 @@ if(WITH_GPU) nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context) nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context) nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function) + nv_library(cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto) cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function) @@ -30,6 +31,7 @@ else() cc_library(maxouting SRCS maxouting.cc DEPS device_context) cc_library(unpooling SRCS unpooling.cc DEPS device_context) cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) + cc_library(cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context) endif() cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/cos_sim_functor.cc b/paddle/operators/math/cos_sim_functor.cc new file mode 100644 index 00000000000..f52a82b1087 --- /dev/null +++ b/paddle/operators/math/cos_sim_functor.cc @@ -0,0 +1,48 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/cos_sim_functor.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct CosSimDyFunctor { + void operator()(const platform::CPUDeviceContext& ctx, const T* x_norm, + const T* y_norm, const T* x, const T* y, const T* z, + const T* dz, const size_t rows, const size_t cols, + T* dy) const { + for (size_t row_id = 0; row_id < rows; ++row_id) { + auto xy_norm_prod = x_norm[row_id] * y_norm[0]; + auto dz_data = dz[row_id]; + auto z_data = z[row_id]; + auto* x_data = x + cols * row_id; + auto reciprocal_xy_norm_prod = 1 / xy_norm_prod; + + auto y_norm_square = y_norm[0] * y_norm[0]; + auto reciprocal_y_norm_square = 1 / y_norm_square; + for (size_t i = 0; i < cols; ++i) { + dy[i] += dz_data * (x_data[i] * reciprocal_xy_norm_prod - + z_data * y[i] * reciprocal_y_norm_square); + } + } + } +}; + +template class CosSimDyFunctor; +template class CosSimDyFunctor; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/cos_sim_functor.cu b/paddle/operators/math/cos_sim_functor.cu new file mode 100644 index 00000000000..fb19a8b38a4 --- /dev/null +++ b/paddle/operators/math/cos_sim_functor.cu @@ -0,0 +1,64 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/cos_sim_functor.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math { + +template +__global__ void CosSimDyKernel(const T* x_norm, const T* y_norm, const T* x, + const T* y, const T* z, const T* dz, + const size_t rows, const size_t cols, T* dy) { + int grid_size = blockDim.x * gridDim.x; + T y_norm_data = y_norm[0]; + for (int row_id = blockIdx.x * blockDim.x + threadIdx.x; row_id < rows; + row_id += grid_size) { + T xy_norm_prod = x_norm[row_id] * y_norm_data; + T dz_data = dz[row_id]; + T z_data = z[row_id]; + const T* x_data = x + cols * row_id; + T reciprocal_xy_norm_prod = 1 / xy_norm_prod; + + T y_norm_square = y_norm_data * y_norm_data; + T reciprocal_y_norm_square = 1 / y_norm_square; + for (size_t i = 0; i < cols; ++i) { + T dy_data = dz_data * (x_data[i] * reciprocal_xy_norm_prod - + z_data * y[i] * reciprocal_y_norm_square); + platform::CudaAtomicAdd(dy + i, dy_data); + } + } +} + +template +struct CosSimDyFunctor { + void operator()(const platform::CUDADeviceContext& ctx, const T* x_norm, + const T* y_norm, const T* x, const T* y, const T* z, + const T* dz, const size_t rows, const size_t cols, + T* dy) const { + const int block_size = 512; + dim3 threads(block_size, 1); + dim3 grid(1, (rows + block_size - 1) / block_size); + CosSimDyKernel<<>>( + x_norm, y_norm, x, y, z, dz, rows, cols, dy); + } +}; + +template class CosSimDyFunctor; +template class CosSimDyFunctor; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/cos_sim_functor.h b/paddle/operators/math/cos_sim_functor.h new file mode 100644 index 00000000000..aae8ab5b7a9 --- /dev/null +++ b/paddle/operators/math/cos_sim_functor.h @@ -0,0 +1,166 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/platform/device_context.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct CosSimFunctor { + CosSimFunctor(const T* x, const T* y, T* x_norm, T* y_norm, T* z, int cols) + : x_norm_(x_norm), + y_norm_(y_norm), + x_(x), + y_(y), + z_(z), + cols_(static_cast(cols)) {} + + inline HOSTDEVICE void operator()(size_t row_id) const { + auto* x = x_ + cols_ * row_id; + T xx = 0, xy = 0, yy = 0; + if (same_row) { + auto* y = y_ + cols_ * row_id; + T tep_x, tep_y; + for (size_t i = 0; i < cols_; ++i) { + tep_x = x[i]; + tep_y = y[i]; + xx += tep_x * tep_x; + yy += tep_y * tep_y; + xy += tep_x * tep_y; + } + xx = sqrt(xx); + yy = sqrt(yy); + y_norm_[row_id] = yy; + x_norm_[row_id] = xx; + z_[row_id] = xy / (xx * yy); + } else { // This can be wrote in a better way. + T tep_x, tep_y; + for (size_t i = 0; i < cols_; ++i) { + tep_x = x[i]; + tep_y = y_[i]; + xx += tep_x * tep_x; + yy += tep_y * tep_y; + xy += tep_x * tep_y; + } + xx = sqrt(xx); + yy = sqrt(yy); + if (row_id == 0) y_norm_[0] = yy; + x_norm_[row_id] = xx; + z_[row_id] = xy / (xx * yy); + } + } + + T* x_norm_; + T* y_norm_; + const T* x_; + const T* y_; + T* z_; + const size_t cols_; +}; + +template +struct CosSimGradFunctor { + CosSimGradFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y, + const T* z, const T* dz, T* dx, int cols) + : x_norm_(x_norm), + y_norm_(y_norm), + x_(x), + y_(y), + z_(z), + dz_(dz), + dx_(dx), + cols_(static_cast(cols)) {} + + inline HOSTDEVICE void operator()(size_t row_id) const { + auto x_norm_square = x_norm_[row_id] * x_norm_[row_id]; + auto xy_norm_prod = x_norm_[row_id] * y_norm_[row_id]; + auto dz = dz_[row_id]; + auto z = z_[row_id]; + + auto* dx = dx_ + cols_ * row_id; + auto* x = x_ + cols_ * row_id; + auto* y = y_ + cols_ * row_id; + + auto reciprocal_xy_norm_prod = 1 / xy_norm_prod; + auto reciprocal_x_norm_square = 1 / x_norm_square; + for (size_t i = 0; i < cols_; ++i) { + dx[i] = dz * (y[i] * reciprocal_xy_norm_prod - + z * x[i] * reciprocal_x_norm_square); + } + } + + const T* x_norm_; + const T* y_norm_; + const T* x_; + const T* y_; + const T* z_; + const T* dz_; + T* dx_; + const size_t cols_; +}; + +template +struct CosSimDxFunctor { + CosSimDxFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y, + const T* z, const T* dz, T* dx, int cols) + : x_norm_(x_norm), + y_norm_(y_norm), + x_(x), + y_(y), + z_(z), + dz_(dz), + dx_(dx), + cols_(static_cast(cols)) {} + + inline HOSTDEVICE void operator()(size_t row_id) const { + auto xy_norm_prod = x_norm_[row_id] * y_norm_[0]; + auto dz = dz_[row_id]; + auto z = z_[row_id]; + auto* x = x_ + cols_ * row_id; + auto reciprocal_xy_norm_prod = 1 / xy_norm_prod; + auto x_norm_square = x_norm_[row_id] * x_norm_[row_id]; + auto* dx = dx_ + cols_ * row_id; + auto reciprocal_x_norm_square = 1 / x_norm_square; + + for (size_t i = 0; i < cols_; ++i) { + dx[i] = dz * (y_[i] * reciprocal_xy_norm_prod - + z * x[i] * reciprocal_x_norm_square); + } + } + const T* x_norm_; + const T* y_norm_; + const T* x_; + const T* y_; + const T* z_; + const T* dz_; + T* dx_; + const size_t cols_; +}; + +template +struct CosSimDyFunctor { + void operator()(const DeviceContext& ctx, const T* x_norm, const T* y_norm, + const T* x, const T* y, const T* z, const T* dz, + const size_t rows, const size_t cols, T* dy) const; +}; + +} // namespace math +} // namespace operators +} // namespace paddle -- GitLab