diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 3e686b1c415e61a24e0f6729555e672721cf806f..7f56737ca99906ad7a62ec2fb8c4a6c1f3d5a221 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -229,6 +229,7 @@ op_library(lstm_op DEPS sequence2batch lstm_compute) op_library(conv_transpose_op DEPS vol2col) op_library(gru_op DEPS sequence2batch gru_compute) op_library(recurrent_op DEPS executor) +op_library(cos_sim_op DEPS cos_sim_functor) # FIXME(typhoonzero): save/load depends lodtensor serialization functions op_library(save_op DEPS lod_tensor) op_library(load_op DEPS lod_tensor) diff --git a/paddle/operators/cos_sim_op.h b/paddle/operators/cos_sim_op.h index e2b6282c0913e8ad16f8e3f6c3054f9567822d15..eadcca55f9bfc3e59f329df8ff419ad4c5a29007 100644 --- a/paddle/operators/cos_sim_op.h +++ b/paddle/operators/cos_sim_op.h @@ -13,19 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once -#include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/cos_sim_functor.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/platform/for_range.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -using EigenMatrix = framework::EigenMatrix; -template -using EigenVector = framework::EigenVector; template class CosSimKernel : public framework::OpKernel { @@ -41,28 +37,25 @@ class CosSimKernel : public framework::OpKernel { out_x_norm->mutable_data(context.GetPlace()); out_y_norm->mutable_data(context.GetPlace()); - // convert Tensor to Eigen Tensor int rows_x = in_x->dims()[0]; int rows_y = in_y->dims()[0]; - auto x = EigenMatrix::Reshape(*in_x, 1); - auto y = EigenMatrix::Reshape(*in_y, 1); - auto z = EigenVector::Flatten(*out_z); - auto x_norm = EigenVector::Flatten(*out_x_norm); - auto y_norm = EigenVector::Flatten(*out_y_norm); - // compute - auto& place = - *context.template device_context().eigen_device(); - auto row_along = Eigen::array({{1}}); - x_norm.device(place) = x.square().sum(row_along).sqrt(); - y_norm.device(place) = y.square().sum(row_along).sqrt(); + int cols = framework::product(in_x->dims()) / rows_x; + if (rows_x == rows_y) { - auto xy = (x * y).sum(Eigen::array({{1}})); - z.device(place) = xy / x_norm / y_norm; + math::CosSimFunctor functor( + in_x->data(), in_y->data(), out_x_norm->data(), + out_y_norm->data(), out_z->data(), cols); + platform::ForRange for_range( + static_cast(context.device_context()), rows_x); + for_range(functor); } else { - Eigen::DSizes bcast(rows_x, 1); - auto xy = (x * y.broadcast(bcast)).sum(row_along); - z.device(place) = xy / x_norm / y_norm.broadcast(bcast); + math::CosSimFunctor functor( + in_x->data(), in_y->data(), out_x_norm->data(), + out_y_norm->data(), out_z->data(), cols); + platform::ForRange for_range( + static_cast(context.device_context()), rows_x); + for_range(functor); } } }; 
@@ -81,62 +74,54 @@ class CosSimGradKernel : public framework::OpKernel { auto* out_grad_y = context.Output(framework::GradVarName("Y")); auto* in_grad_z = context.Input(framework::GradVarName("Out")); - // convert Tensor to Eigen Tensor - auto x = EigenMatrix::Reshape(*in_x, 1); - auto y = EigenMatrix::Reshape(*in_y, 1); - auto z = EigenMatrix::Reshape(*in_z, 1); - auto x_norm = EigenMatrix::Reshape(*in_x_norm, 1); - auto y_norm = EigenMatrix::Reshape(*in_y_norm, 1); - auto dz = EigenMatrix::Reshape(*in_grad_z, 1); - // compute gradident int rows_x = in_x->dims()[0]; int rows_y = in_y->dims()[0]; int cols = framework::product(in_x->dims()) / rows_x; - Eigen::DSizes bcast_cols(1, cols); - auto z_bcast = z.broadcast(bcast_cols); - auto dz_bcast = dz.broadcast(bcast_cols); - auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast_cols); - auto& place = - *context.template device_context().eigen_device(); + if (rows_x == rows_y) { - auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_cols); - auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast_cols); - // compute dx if (out_grad_x) { - out_grad_x->mutable_data(context.GetPlace()); - auto dx = EigenMatrix::Reshape(*out_grad_x, 1); - auto grad = y / norm_prod_bcast - z_bcast * x / x_snorm_bcast; - dx.device(place) = dz_bcast * grad; + math::CosSimGradFunctor functor( + in_x_norm->data(), in_y_norm->data(), in_x->data(), + in_y->data(), in_z->data(), in_grad_z->data(), + out_grad_x->mutable_data(context.GetPlace()), cols); + platform::ForRange for_range( + static_cast(context.device_context()), + rows_x); + for_range(functor); } - // compute dy if (out_grad_y) { - out_grad_y->mutable_data(context.GetPlace()); - auto dy = EigenMatrix::Reshape(*out_grad_y, 1); - auto grad = x / norm_prod_bcast - z_bcast * y / y_snorm_bcast; - dy.device(place) = dz_bcast * grad; + math::CosSimGradFunctor functor( + in_y_norm->data(), in_x_norm->data(), in_y->data(), + in_x->data(), in_z->data(), in_grad_z->data(), + 
out_grad_y->mutable_data(context.GetPlace()), cols); + platform::ForRange for_range( + static_cast(context.device_context()), + rows_x); + for_range(functor); } } else { - Eigen::DSizes bcast_rows(rows_x, 1); - Eigen::DSizes bcast_rows_cols(rows_x, cols); - auto y_bcast = y.broadcast(bcast_rows); - auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_rows_cols); - auto norm_prod_bcast = (x_norm * y_norm.eval().broadcast(bcast_rows)) - .eval() - .broadcast(bcast_cols); - // compute dx if (out_grad_x) { - out_grad_x->mutable_data(context.GetPlace()); - auto dx = EigenMatrix::Reshape(*out_grad_x, 1); - auto grad = y_bcast / norm_prod_bcast - z_bcast * x / x_snorm_bcast; - dx.device(place) = dz_bcast * grad; + math::CosSimDxFunctor functor( + in_x_norm->data(), in_y_norm->data(), in_x->data(), + in_y->data(), in_z->data(), in_grad_z->data(), + out_grad_x->mutable_data(context.GetPlace()), cols); + platform::ForRange for_range( + static_cast(context.device_context()), + rows_x); + for_range(functor); } - // compute dy if (out_grad_y) { out_grad_y->mutable_data(context.GetPlace()); - auto dy = EigenVector::Flatten(*out_grad_y); - auto grad = x / norm_prod_bcast - z_bcast * y_bcast / y_snorm_bcast; - dy.device(place) = (dz_bcast * grad).sum(Eigen::array({{0}})); + math::SetConstant set_zero; + auto& dev_ctx = context.template device_context(); + set_zero(dev_ctx, out_grad_y, static_cast(0)); + + math::CosSimDyFunctor functor; + functor(dev_ctx, in_x_norm->data(), in_y_norm->data(), + in_x->data(), in_y->data(), in_z->data(), + in_grad_z->data(), static_cast(rows_x), + static_cast(cols), out_grad_y->data()); } } } diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index b97faec4ed687c1cf8d746cdf615e86fd79ca921..7ebcfb9ab9f30e3b0f13d3646a59d008335b232d 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -16,6 +16,7 @@ if(WITH_GPU) nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS 
device_context) nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context) nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function) + nv_library(cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto) cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function) @@ -30,6 +31,7 @@ else() cc_library(maxouting SRCS maxouting.cc DEPS device_context) cc_library(unpooling SRCS unpooling.cc DEPS device_context) cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) + cc_library(cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context) endif() cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/cos_sim_functor.cc b/paddle/operators/math/cos_sim_functor.cc new file mode 100644 index 0000000000000000000000000000000000000000..f52a82b10870205ff490b2bf2187a2ada1afe5e8 --- /dev/null +++ b/paddle/operators/math/cos_sim_functor.cc @@ -0,0 +1,48 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include "paddle/operators/math/cos_sim_functor.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+// NOTE(review): the template parameter/argument lists in this file were lost
+// when the patch text was extracted. They are restored here from the way the
+// functor is declared and called elsewhere in this patch -- confirm the
+// argument order and the two instantiated element types against the commit.
+
+/**
+ * CPU specialization of CosSimDyFunctor.
+ *
+ * Handles the case where one row of y (y[0..cols), y_norm[0]) is paired with
+ * every one of the `rows` rows of x, and accumulates each row's contribution
+ * into dy[0..cols):
+ *
+ *   dy[i] += dz[r] * (x[r * cols + i] / (x_norm[r] * y_norm[0])
+ *                     - z[r] * y[i] / (y_norm[0] * y_norm[0]))
+ *
+ * dy is only added to, never cleared, so it must be zero-filled by the caller
+ * before this functor runs. `ctx` is not used on the CPU path.
+ */
+template <typename T>
+struct CosSimDyFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& ctx, const T* x_norm,
+                  const T* y_norm, const T* x, const T* y, const T* z,
+                  const T* dz, const size_t rows, const size_t cols,
+                  T* dy) const {
+    // Terms that depend only on y are identical for every row, so compute
+    // them once instead of once per row.
+    const T y_norm_data = y_norm[0];
+    const T reciprocal_y_norm_square = 1 / (y_norm_data * y_norm_data);
+
+    for (size_t row_id = 0; row_id < rows; ++row_id) {
+      const T reciprocal_xy_norm_prod = 1 / (x_norm[row_id] * y_norm_data);
+      const T dz_data = dz[row_id];
+      const T z_data = z[row_id];
+      const T* x_data = x + cols * row_id;
+      for (size_t i = 0; i < cols; ++i) {
+        dy[i] += dz_data * (x_data[i] * reciprocal_xy_norm_prod -
+                            z_data * y[i] * reciprocal_y_norm_square);
+      }
+    }
+  }
+};
+
+template class CosSimDyFunctor<platform::CPUDeviceContext, float>;
+template class CosSimDyFunctor<platform::CPUDeviceContext, double>;
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/cos_sim_functor.cu b/paddle/operators/math/cos_sim_functor.cu
new file mode 100644
index 0000000000000000000000000000000000000000..fb19a8b38a44a33728122ab87f99c118b34bb973
--- /dev/null
+++ b/paddle/operators/math/cos_sim_functor.cu
@@ -0,0 +1,64 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/
+
+#include "paddle/operators/math/cos_sim_functor.h"
+#include "paddle/platform/cuda_helper.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+// NOTE(review): the template parameter/argument lists and the kernel launch
+// configuration in this file were lost when the patch text was extracted.
+// They are restored here from the surrounding code (`grid`, `threads` and
+// `ctx` are otherwise unused) -- confirm against the original commit, in
+// particular the shared-memory size and the stream accessor on `ctx`.
+
+/**
+ * Accumulates dY for the case where one row of y is paired with every row
+ * of x.
+ *
+ * Each thread processes whole rows of x, starting at its global x index and
+ * advancing by the total number of threads along x, and atomically adds that
+ * row's contribution to dy[0..cols):
+ *
+ *   dy[i] += dz[r] * (x[r * cols + i] / (x_norm[r] * y_norm[0])
+ *                     - z[r] * y[i] / (y_norm[0] * y_norm[0]))
+ *
+ * Only blockIdx.x / blockDim.x / gridDim.x are consulted, so the launch must
+ * lay its blocks out along x. dy is only added to, so it must be zero-filled
+ * before the kernel runs.
+ */
+template <typename T>
+__global__ void CosSimDyKernel(const T* x_norm, const T* y_norm, const T* x,
+                               const T* y, const T* z, const T* dz,
+                               const size_t rows, const size_t cols, T* dy) {
+  // Row indices are kept in size_t so they compare and multiply with
+  // rows/cols without signed/unsigned mixing.
+  const size_t grid_size = static_cast<size_t>(blockDim.x) * gridDim.x;
+
+  // Terms that depend only on y are identical for every row.
+  const T y_norm_data = y_norm[0];
+  const T reciprocal_y_norm_square = 1 / (y_norm_data * y_norm_data);
+
+  for (size_t row_id = static_cast<size_t>(blockIdx.x) * blockDim.x +
+                       threadIdx.x;
+       row_id < rows; row_id += grid_size) {
+    const T reciprocal_xy_norm_prod = 1 / (x_norm[row_id] * y_norm_data);
+    const T dz_data = dz[row_id];
+    const T z_data = z[row_id];
+    const T* x_data = x + cols * row_id;
+    for (size_t i = 0; i < cols; ++i) {
+      const T dy_data = dz_data * (x_data[i] * reciprocal_xy_norm_prod -
+                                   z_data * y[i] * reciprocal_y_norm_square);
+      // Many rows (threads) contribute to the same dy[i].
+      platform::CudaAtomicAdd(dy + i, dy_data);
+    }
+  }
+}
+
+/**
+ * CUDA specialization of CosSimDyFunctor: launches CosSimDyKernel with one
+ * thread per row of x (512 threads per block).
+ */
+template <typename T>
+struct CosSimDyFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& ctx, const T* x_norm,
+                  const T* y_norm, const T* x, const T* y, const T* z,
+                  const T* dz, const size_t rows, const size_t cols,
+                  T* dy) const {
+    // No rows means nothing to accumulate, and a grid with a zero dimension
+    // is not a valid launch configuration.
+    if (rows == 0) return;
+
+    const int block_size = 512;
+    dim3 threads(block_size, 1);
+    // The blocks must be spread along x. The kernel computes row ids from
+    // blockIdx.x and gridDim.x only, so a (1, N) grid would make each of the
+    // N blocks walk over *all* rows and add every contribution N times once
+    // rows exceeds block_size.
+    dim3 grid((rows + block_size - 1) / block_size, 1);
+    CosSimDyKernel<T><<<grid, threads, 0, ctx.stream()>>>(
+        x_norm, y_norm, x, y, z, dz, rows, cols, dy);
+  }
+};
+
+template class CosSimDyFunctor<platform::CUDADeviceContext, float>;
+template class CosSimDyFunctor<platform::CUDADeviceContext, double>;
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/cos_sim_functor.h b/paddle/operators/math/cos_sim_functor.h
new file mode 100644
index 0000000000000000000000000000000000000000..aae8ab5b7a937c016e8a45e34b22aba7a1df3066
--- /dev/null
+++ b/paddle/operators/math/cos_sim_functor.h
@@ -0,0 +1,166 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+// NOTE(review): the two system header names below, every template parameter
+// list and the static_cast target types in this file were lost when the patch
+// text was extracted. They are restored from usage (sqrt, size_t, T,
+// same_row, DeviceContext) -- confirm against the original commit.
+#include <math.h>
+#include <stdlib.h>
+#include "paddle/platform/device_context.h"
+#include "paddle/platform/hostdevice.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+/**
+ * Forward functor, invoked once per row of x.
+ *
+ * For row `row_id` it computes, over the row's cols_ elements:
+ *   x_norm[row_id] = sqrt(sum(x * x))
+ *   z[row_id]      = sum(x * y) / (sqrt(sum(x * x)) * sqrt(sum(y * y)))
+ *
+ * same_row == true : y has its own row for every row of x, and
+ *                    y_norm[row_id] is written as well.
+ * same_row == false: the first cols_ elements of y are used for every row of
+ *                    x; only the call for row 0 stores y_norm[0].
+ */
+template <typename T, bool same_row>
+struct CosSimFunctor {
+  CosSimFunctor(const T* x, const T* y, T* x_norm, T* y_norm, T* z, int cols)
+      : x_norm_(x_norm),
+        y_norm_(y_norm),
+        x_(x),
+        y_(y),
+        z_(z),
+        cols_(static_cast<size_t>(cols)) {}
+
+  inline HOSTDEVICE void operator()(size_t row_id) const {
+    const T* x = x_ + cols_ * row_id;
+    // The two modes differ only in which row of y is read and in where the
+    // norm of y is stored; the accumulation itself is shared.
+    const T* y = same_row ? y_ + cols_ * row_id : y_;
+
+    T xx = 0, xy = 0, yy = 0;
+    for (size_t i = 0; i < cols_; ++i) {
+      const T tep_x = x[i];
+      const T tep_y = y[i];
+      xx += tep_x * tep_x;
+      yy += tep_y * tep_y;
+      xy += tep_x * tep_y;
+    }
+    xx = sqrt(xx);
+    yy = sqrt(yy);
+
+    if (same_row) {
+      y_norm_[row_id] = yy;
+    } else if (row_id == 0) {
+      // Every call recomputes the norm of the shared y row for its own use,
+      // but only one of them needs to publish it.
+      y_norm_[0] = yy;
+    }
+    x_norm_[row_id] = xx;
+    z_[row_id] = xy / (xx * yy);
+  }
+
+  T* x_norm_;
+  T* y_norm_;
+  const T* x_;
+  const T* y_;
+  T* z_;
+  const size_t cols_;
+};
+
+/**
+ * Backward functor for the case where x and y have the same number of rows,
+ * invoked once per row. For row `row_id` it writes
+ *
+ *   dx[i] = dz * (y[i] / (x_norm * y_norm) - z * x[i] / (x_norm * x_norm))
+ *
+ * using that row's x, y, norms, z and dz. The op kernel in this patch also
+ * constructs it with the x and y arguments exchanged to obtain the gradient
+ * with respect to y.
+ */
+template <typename T>
+struct CosSimGradFunctor {
+  CosSimGradFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
+                    const T* z, const T* dz, T* dx, int cols)
+      : x_norm_(x_norm),
+        y_norm_(y_norm),
+        x_(x),
+        y_(y),
+        z_(z),
+        dz_(dz),
+        dx_(dx),
+        cols_(static_cast<size_t>(cols)) {}
+
+  inline HOSTDEVICE void operator()(size_t row_id) const {
+    auto x_norm_square = x_norm_[row_id] * x_norm_[row_id];
+    auto xy_norm_prod = x_norm_[row_id] * y_norm_[row_id];
+    auto dz = dz_[row_id];
+    auto z = z_[row_id];
+
+    auto* dx = dx_ + cols_ * row_id;
+    auto* x = x_ + cols_ * row_id;
+    auto* y = y_ + cols_ * row_id;
+
+    // Divide once per row, multiply per element.
+    auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
+    auto reciprocal_x_norm_square = 1 / x_norm_square;
+    for (size_t i = 0; i < cols_; ++i) {
+      dx[i] = dz * (y[i] * reciprocal_xy_norm_prod -
+                    z * x[i] * reciprocal_x_norm_square);
+    }
+  }
+
+  const T* x_norm_;
+  const T* y_norm_;
+  const T* x_;
+  const T* y_;
+  const T* z_;
+  const T* dz_;
+  T* dx_;
+  const size_t cols_;
+};
+
+/**
+ * Backward functor for dX when a single row of y is shared by every row of
+ * x, invoked once per row of x. Same formula as CosSimGradFunctor, except
+ * that y is always read from its first cols_ elements and its norm from
+ * y_norm[0].
+ */
+template <typename T>
+struct CosSimDxFunctor {
+  CosSimDxFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
+                  const T* z, const T* dz, T* dx, int cols)
+      : x_norm_(x_norm),
+        y_norm_(y_norm),
+        x_(x),
+        y_(y),
+        z_(z),
+        dz_(dz),
+        dx_(dx),
+        cols_(static_cast<size_t>(cols)) {}
+
+  inline HOSTDEVICE void operator()(size_t row_id) const {
+    auto xy_norm_prod = x_norm_[row_id] * y_norm_[0];
+    auto dz = dz_[row_id];
+    auto z = z_[row_id];
+    auto* x = x_ + cols_ * row_id;
+    auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
+    auto x_norm_square = x_norm_[row_id] * x_norm_[row_id];
+    auto* dx = dx_ + cols_ * row_id;
+    auto reciprocal_x_norm_square = 1 / x_norm_square;
+
+    for (size_t i = 0; i < cols_; ++i) {
+      dx[i] = dz * (y_[i] * reciprocal_xy_norm_prod -
+                    z * x[i] * reciprocal_x_norm_square);
+    }
+  }
+  const T* x_norm_;
+  const T* y_norm_;
+  const T* x_;
+  const T* y_;
+  const T* z_;
+  const T* dz_;
+  T* dx_;
+  const size_t cols_;
+};
+
+/**
+ * Backward functor for dY when a single row of y is shared by every row of
+ * x. Only declared here; the per-device definitions are the specializations
+ * in cos_sim_functor.cc and cos_sim_functor.cu, which accumulate into dy.
+ */
+template <typename DeviceContext, typename T>
+struct CosSimDyFunctor {
+  void operator()(const DeviceContext& ctx, const T* x_norm, const T* y_norm,
+                  const T* x, const T* y, const T* z, const T* dz,
+                  const size_t rows, const size_t cols, T* dy) const;
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle