diff --git a/paddle/cuda/include/hl_matrix.h b/paddle/cuda/include/hl_matrix.h
index 40828dd5cc76f4197e6cfbb1121f2eef2c1ac580..6f21b82afdc6cdde785fdd8f13eef17a0fdd6324 100644
--- a/paddle/cuda/include/hl_matrix.h
+++ b/paddle/cuda/include/hl_matrix.h
@@ -188,48 +188,6 @@ extern void hl_param_relu_backward_diff(real* grad_o,
                                         int width,
                                         int height,
                                         int partial_sum);
-/**
- * @brief cos sim forward
- *
- * @param[out]  output         output data
- * @param[in]   input1         input1 data (matrix)
- * @param[in]   input2         input2 data (matrix or vector)
- * @param[in]   width          matrix width
- * @param[in]   input1_height  input1 height
- * @param[in]   input2_height  input2 height
- * @param[in]   scale          scale factor
- */
-extern void hl_cossim(real* output,
-                      real* input1,
-                      real* input2,
-                      int width,
-                      int input1_height,
-                      int input2_height,
-                      real scale);
-/**
- * @brief cos sim derivate
- *
- * @param[in]   grad           output grad
- * @param[in]   output         output data
- * @param[in]   prevOutX       input1 data
- * @param[in]   prevOutY       input2 data
- * @param[out]  prevGradX      input1 grad
- * @param[out]  prevGradY      input2 grad
- * @param[in]   width          matrix width
- * @param[in]   input1_height  input1 height
- * @param[in]   input2_height  input2 height
- * @param[in]   scale          scale factor
- */
-extern void hl_cossim_derivative(real* grad,
-                                 real* output,
-                                 real* prevOutX,
-                                 real* prevOutY,
-                                 real* prevGradX,
-                                 real* prevGradY,
-                                 int width,
-                                 int input1_height,
-                                 int input2_height,
-                                 real scale);
 
 /**
  * @brief Matrix addition: A_d[i][j] += scale * B_d[j/channel].
diff --git a/paddle/cuda/include/stub/hl_matrix_stub.h b/paddle/cuda/include/stub/hl_matrix_stub.h
index a1712d1e4d2a5dc80526b7d7b5ad7bd4f5d8c1ed..f4e6461cdcf198637b2c96fee88d1de2766aaf18 100644
--- a/paddle/cuda/include/stub/hl_matrix_stub.h
+++ b/paddle/cuda/include/stub/hl_matrix_stub.h
@@ -74,25 +74,6 @@ inline void hl_param_relu_backward_diff(real* grad_o,
                                         int height,
                                         int partial_sum) {}
 
-inline void hl_cossim(real* output,
-                      real* input1,
-                      real* input2,
-                      int width,
-                      int input1_height,
-                      int input2_height,
-                      real scale) {}
-
-inline void hl_cossim_derivative(real* grad,
-                                 real* output,
-                                 real* prevOutX,
-                                 real* prevOutY,
-                                 real* prevGradX,
-                                 real* prevGradY,
-                                 int width,
-                                 int input1_height,
-                                 int input2_height,
-                                 real scale) {}
-
 inline void hl_matrix_add_shared_bias(real* A_d,
                                       real* B_d,
                                       const int channel,
diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu
index cd23bd31057c5c8cd10173bc5fa5fa67f2d0e422..96c07d9c3b7a37daa9198fd7ea66b7d811600348 100644
--- a/paddle/cuda/src/hl_cuda_matrix.cu
+++ b/paddle/cuda/src/hl_cuda_matrix.cu
@@ -584,177 +584,6 @@ void hl_param_relu_backward_diff(real* grad_o,
   CHECK_SYNC("hl_param_relu_backward_diff failed");
 }
 
-template <int blockSize>
-__global__ void KeCosSim(real* output,
-                         real* input1,
-                         real* input2,
-                         int width,
-                         int input1_height,
-                         int input2_height,
-                         real scale) {
-  const int ty = blockIdx.y;
-  int tid = threadIdx.x;
-
-  __shared__ real xx[blockSize];
-  __shared__ real yy[blockSize];
-  __shared__ real xy[blockSize];
-
-  xx[tid] = 0.0;
-  yy[tid] = 0.0;
-  xy[tid] = 0.0;
-  __syncthreads();
-
-  input1 += ty * width;
-  if (input2_height > 1) {
-    input2 += ty * width;
-  }
-  for (int index = tid; index < width; index += blockSize) {
-    real x = input1[index];
-    real y = input2[index];
-    xx[tid] += x * x;
-    yy[tid] += y * y;
-    xy[tid] += x * y;
-  }
-  __syncthreads();
-
-  for (int s = blockSize / 2; s > 0; s >>= 1) {
-    if (tid < s) {
-      xx[tid] += xx[tid + s];
-      yy[tid] += yy[tid + s];
-      xy[tid] += xy[tid + s];
-    }
-    __syncthreads();
-  }
-  if (tid == 0) {
-    output[ty] = scale * xy[0] / (sqrt(xx[0]) * sqrt(yy[0]));
-  }
-}
-
-void hl_cossim(real* output,
-               real* input1,
-               real* input2,
-               int width,
-               int input1_height,
-               int input2_height,
-               real scale) {
-  CHECK_NOTNULL(output);
-  CHECK_NOTNULL(input1);
-  CHECK_NOTNULL(input2);
-  const int blockSize = 256;
-  dim3 threads(blockSize, 1);
-  dim3 grid(1, input1_height);
-
-  KeCosSim<blockSize><<<grid, threads, 0, STREAM_DEFAULT>>>
-      (output, input1, input2, width, input1_height, input2_height, scale);
-  CHECK_SYNC("hl_cossim failed");
-}
-
-template <int blockSize>
-__global__ void KeCosSimDerivative(real* grad,
-                                   real* output,
-                                   real* prevOutX,
-                                   real* prevOutY,
-                                   real* prevGradX,
-                                   real* prevGradY,
-                                   int width,
-                                   int input1_height,
-                                   int input2_height,
-                                   real scale) {
-  const int ty = blockIdx.y;
-  int tid = threadIdx.x;
-
-  __shared__ real xx[blockSize];
-  __shared__ real yy[blockSize];
-  __shared__ real xy[blockSize];
-
-  xx[tid] = 0.0;
-  yy[tid] = 0.0;
-  xy[tid] = 0.0;
-  __syncthreads();
-
-  prevOutX += ty * width;
-  prevGradX += ty * width;
-  if (input2_height > 1) {
-    prevOutY += ty * width;
-    prevGradY += ty * width;
-  }
-  for (int index = tid; index < width; index += blockSize) {
-    real x = prevOutX[index];
-    real y = prevOutY[index];
-    xx[tid] += x * x;
-    yy[tid] += y * y;
-    xy[tid] += x * y;
-  }
-  __syncthreads();
-
-  for (int s = blockSize / 2; s > 0; s >>= 1) {
-    if (tid < s) {
-      xx[tid] += xx[tid + s];
-      yy[tid] += yy[tid + s];
-      xy[tid] += xy[tid + s];
-    }
-    __syncthreads();
-  }
-  if (xy[0] == 0) {
-    real reciprocal = 1.0 / (sqrt(xx[0]) * sqrt(yy[0]));
-    for (int index = tid; index < width; index += blockSize) {
-      prevGradX[index] +=
-          scale * grad[ty] * prevOutY[index] * reciprocal;
-      if (input2_height > 1) {
-        prevGradY[index] +=
-            scale * grad[ty] * prevOutX[index] * reciprocal;
-      } else {
-        paddle::paddleAtomicAdd(prevGradY + index,
-            scale * grad[ty] * prevOutX[index] * reciprocal);
-      }
-    }
-  } else {
-    real reciprocalXY = 1.0 / xy[0];
-    real reciprocalSquareSumX = 1.0 / xx[0];
-    real reciprocalSquareSumY = 1.0 / yy[0];
-    for (int index = tid; index < width; index += blockSize) {
-      prevGradX[index] += output[ty] * grad[ty] *
-          (prevOutY[index] * reciprocalXY -
-           prevOutX[index] * reciprocalSquareSumX);
-      if (input2_height > 1) {
-        prevGradY[index] += output[ty] * grad[ty] *
-            (prevOutX[index] * reciprocalXY -
-             prevOutY[index] * reciprocalSquareSumY);
-      } else {
-        paddle::paddleAtomicAdd(prevGradY + index, output[ty] * grad[ty] *
-            (prevOutX[index] * reciprocalXY -
-             prevOutY[index] * reciprocalSquareSumY));
-      }
-    }
-  }
-}
-
-void hl_cossim_derivative(real* grad,
-                          real* output,
-                          real* prevOutX,
-                          real* prevOutY,
-                          real* prevGradX,
-                          real* prevGradY,
-                          int width,
-                          int input1_height,
-                          int input2_height,
-                          real scale) {
-  CHECK_NOTNULL(grad);
-  CHECK_NOTNULL(output);
-  CHECK_NOTNULL(prevOutX);
-  CHECK_NOTNULL(prevOutY);
-  CHECK_NOTNULL(prevGradX);
-  CHECK_NOTNULL(prevGradY);
-  const int blockSize = 256;
-  dim3 threads(blockSize, 1);
-  dim3 grid(1, input1_height);
-  KeCosSimDerivative<blockSize><<<grid, threads, 0, STREAM_DEFAULT>>>
-      (grad, output, prevOutX, prevOutY, prevGradX, prevGradY, width,
-       input1_height, input2_height, scale);
-  CHECK_SYNC("hl_cossim_derivate failed");
-}
-
 __global__ void KeMatrixAddSharedBias(real* A,
                                       real* B,
                                       const int channel,
diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt
index fae3b7b20a70b56dc44ea2df637281afe01a7e5a..1522510e8bb9816cb468fcf406e22560163950cc 100644
--- a/paddle/function/CMakeLists.txt
+++ b/paddle/function/CMakeLists.txt
@@ -27,6 +27,7 @@ if(WITH_TESTING)
     add_simple_unittest(ContextProjectionOpTest)
     add_simple_unittest(PadOpTest)
     add_simple_unittest(MulOpTest)
+    add_simple_unittest(CosSimOpTest)
 endif()
 endif()
diff --git a/paddle/function/CosSimOp.cpp b/paddle/function/CosSimOp.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..7ece7b2dfedaf460741c97b5a700eb632d85cabc
--- /dev/null
+++ b/paddle/function/CosSimOp.cpp
@@ -0,0 +1,240 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "CosSimOp.h"
+#include "paddle/math/Matrix.h"
+#include "paddle/math/Vector.h"
+
+namespace paddle {
+/**
+ * Cosine Similarity for CpuMatrix
+ *
+ * \param out_mat, output value, size: nSamples * 1.
+ * \param in1_mat, input value 1, size: nSamples * dim.
+ * \param in2_mat, input value 2, size: n2 * dim (n2 == 1 or n2 == nSamples).
+ * \param scale, default 1.0
+ *
+ */
+template <>
+void CosSimForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat,
+                                    const CpuMatrix& in1_mat,
+                                    const CpuMatrix& in2_mat,
+                                    real scale) {
+  CHECK(out_mat.getData() && in1_mat.getData() && in2_mat.getData());
+  size_t num_samples = out_mat.getHeight();
+  size_t dim = in1_mat.getWidth();
+  /// column vector [nSamples, 1]
+  real* out = out_mat.getData();
+  const real* x = in1_mat.getData();
+  const real* y = in2_mat.getData();
+
+  /// in2 might only have one row or full rows
+  CHECK(in2_mat.getHeight() == 1LU || in2_mat.getHeight() == num_samples);
+  size_t inc = (in2_mat.getHeight() == 1LU) ? 0 : dim;
+  for (size_t i = 0; i < num_samples; ++i, x += dim, y += inc) {
+    real square_sum_x = 0;
+    real square_sum_y = 0;
+    real xy = 0;
+    for (size_t j = 0; j < dim; ++j) {
+      square_sum_x += x[j] * x[j];
+      square_sum_y += y[j] * y[j];
+      xy += x[j] * y[j];
+    }
+    CHECK(square_sum_x > 0 && square_sum_y > 0);
+    out[i] = scale * xy / (std::sqrt(square_sum_x) * std::sqrt(square_sum_y));
+  }
+}
+
+/**
+ * Cosine Similarity
+ * for each row i,
+ *   out[i] = scale * cos(input1[i], input2[i])
+ *          = scale * <input1[i], input2[i]> /
+ *            sqrt(|input1[i]|^2 * |input2[i]|^2)
+ * when input2 only has one row, then for each row i,
+ *   out[i] = cos(input1[i], input2[0])
+ *
+ * \param inputs[0] input matrix 1, size: nSamples * dim.
+ * \param inputs[1] input matrix 2, size: n2 * dim (n2 == 1 or n2 == nSamples).
+ * \param outputs[0] output matrix, size: nSamples * 1.
+ */
+template <DeviceType Device>
+class CosSimForwardFunc : public FunctionBase {
+  void init(const FuncConfig& config) override {
+    scale_ = config.get<real>("scale");
+  }
+
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(inputs.size(), 2UL);
+    CHECK_EQ(outputs.size(), 1UL);
+
+    CHECK_EQ(inputs[0].shape().ndims(), 2UL);
+    CHECK_EQ(inputs[1].shape().ndims(), 2UL);
+    CHECK_EQ(outputs[0].shape().ndims(), 2UL);
+
+    CHECK_EQ(inputs[0].shape()[0], outputs[0].shape()[0]);
+    CHECK_EQ(inputs[0].shape()[1], inputs[1].shape()[1]);
+    CHECK_EQ(outputs[0].shape()[1], 1UL);
+
+    CHECK(outputs[0].data() && inputs[0].data() && inputs[1].data());
+
+    CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
+    auto out_mat = outputs[0].matrix<Device>();
+    const auto in1_mat = inputs[0].matrix<Device>();
+    const auto in2_mat = inputs[1].matrix<Device>();
+
+    CosSimForward<Device>(out_mat, in1_mat, in2_mat, scale_);
+  }
+
+private:
+  real scale_;
+};
+
+/**
+ * Cosine Similarity Derivative for CpuMatrix
+ *
+ * \param in1_grad forward input grad 1, size: nSamples * dim.
+ * \param in2_grad forward input grad 2,
+ *                 size: n2 * dim (n2 == 1 or n2 == nSamples).
+ *
+ * \param out_grad backward loss output grad, size: nSamples * 1.
+ * \param out_val forward output value, size: nSamples * 1.
+ * \param in1_val forward input value 1, size: nSamples * dim.
+ * \param in2_val forward input value 2,
+ *                size: n2 * dim (n2 == 1 or n2 == nSamples).
+ * \param scale, default 1.0
+ */
+template <>
+void CosSimBackward<DEVICE_TYPE_CPU>(const CpuMatrix& out_grad,
+                                     const CpuMatrix& out_val,
+                                     const CpuMatrix& in1_val,
+                                     const CpuMatrix& in2_val,
+                                     CpuMatrix& in1_grad,
+                                     CpuMatrix& in2_grad,
+                                     real scale) {
+  CHECK(out_grad.getData() && out_val.getData() && in1_val.getData() &&
+        in2_val.getData() && in1_grad.getData() && in2_grad.getData());
+  CHECK_EQ(out_val.useGpu_, false) << "Matrix type are GPU, CPU required";
+
+  const real* grad = out_grad.getData();
+  const real* out = out_val.getData();
+  const real* prev_out_x = in1_val.getData();
+  const real* prev_out_y = in2_val.getData();
+  real* prev_grad_x = in1_grad.getData();
+  real* prev_grad_y = in2_grad.getData();
+
+  size_t num_samples = out_grad.getHeight();
+  size_t dim = in1_val.getWidth();
+  CHECK_EQ(in2_val.getHeight(), in2_grad.getHeight());
+  CHECK(in2_val.getHeight() == 1LU || in2_val.getHeight() == num_samples);
+  size_t inc = (in2_val.getHeight() == 1LU) ? 0 : dim;
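For reference, the CPU kernels in this file compute, per row i (writing x = in1[i] and y = in2[i], or y = in2[0] when n2 == 1; notation is mine, not from the patch):

```latex
o_i = s\,\frac{x \cdot y}{\lVert x\rVert\,\lVert y\rVert},
\qquad
\frac{\partial o_i}{\partial x_j}
  = o_i\!\left(\frac{y_j}{x \cdot y} - \frac{x_j}{\lVert x\rVert^{2}}\right).
```

The `else` branch of `CosSimBackward` is exactly this gradient, with `reciprocal_xy` holding 1/(x·y) and `reciprocal_square_sum_x` holding 1/||x||^2. When x·y = 0 that factored form is undefined (and o_i = 0), so the `xy == 0` branch falls back to the directly expanded gradient s·y_j/(||x||·||y||).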
+  for (size_t i = 0; i < num_samples; ++i,
+              prev_out_x += dim,
+              prev_out_y += inc,
+              prev_grad_x += dim,
+              prev_grad_y += inc) {
+    real square_sum_x = 0;
+    real square_sum_y = 0;
+    real xy = 0;
+    for (size_t j = 0; j < dim; ++j) {
+      square_sum_x += prev_out_x[j] * prev_out_x[j];
+      square_sum_y += prev_out_y[j] * prev_out_y[j];
+      xy += prev_out_x[j] * prev_out_y[j];
+    }
+    CHECK(square_sum_x > 0 && square_sum_y > 0);
+    if (xy == 0) {
+      real reciprocal =
+          1.0f / (std::sqrt(square_sum_x) * std::sqrt(square_sum_y));
+      for (size_t j = 0; j < dim; ++j) {
+        prev_grad_x[j] += scale * grad[i] * prev_out_y[j] * reciprocal;
+        prev_grad_y[j] += scale * grad[i] * prev_out_x[j] * reciprocal;
+      }
+    } else {
+      real reciprocal_xy = 1.0f / xy;
+      real reciprocal_square_sum_x = 1.0f / square_sum_x;
+      real reciprocal_square_sum_y = 1.0f / square_sum_y;
+      for (size_t j = 0; j < dim; ++j) {
+        prev_grad_x[j] +=
+            out[i] * grad[i] * (prev_out_y[j] * reciprocal_xy -
+                                prev_out_x[j] * reciprocal_square_sum_x);
+        prev_grad_y[j] +=
+            out[i] * grad[i] * (prev_out_x[j] * reciprocal_xy -
+                                prev_out_y[j] * reciprocal_square_sum_y);
+      }
+    }
+  }
+}
+
+/**
+ * Cosine Similarity backward Derivative
+ *
+ * \param outputs[0] forward input grad 1, size: nSamples * dim.
+ * \param outputs[1] forward input grad 2,
+ *                   size: n2 * dim (n2 == 1 or n2 == nSamples).
+ *
+ * \param inputs[0] backward loss output grad, size: nSamples * 1.
+ * \param inputs[1] forward output value, size: nSamples * 1.
+ * \param inputs[2] forward input value 1, size: nSamples * dim.
+ * \param inputs[3] forward input value 2,
+ *                  size: n2 * dim (n2 == 1 or n2 == nSamples).
+ */
+template <DeviceType Device>
+class CosSimBackwardFunc : public FunctionBase {
+  void init(const FuncConfig& config) override {
+    scale_ = config.get<real>("scale");
+  }
+
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(inputs.size(), 4UL);
+    CHECK_EQ(outputs.size(), 2UL);
+    /// dim of out_grad and out_val == 1, column vector
+    CHECK_EQ(inputs[0].shape()[1], 1UL);
+    CHECK_EQ(inputs[1].shape()[1], 1UL);
+    /// nSamples of out_grad == out_val == in1_val == in1_grad
+    CHECK_EQ(inputs[1].shape()[0], inputs[0].shape()[0]);
+    CHECK_EQ(inputs[2].shape()[0], inputs[0].shape()[0]);
+    CHECK_EQ(outputs[0].shape()[0], inputs[0].shape()[0]);
+    /// dim of in1_val == in2_val == in1_grad == in2_grad
+    CHECK_EQ(inputs[3].shape()[1], inputs[2].shape()[1]);
+    CHECK_EQ(outputs[0].shape()[1], inputs[2].shape()[1]);
+    CHECK_EQ(outputs[1].shape()[1], inputs[2].shape()[1]);
+
+    CHECK(inputs[0].data() && inputs[1].data() && inputs[2].data() &&
+          inputs[3].data() && outputs[0].data() && outputs[1].data());
+
+    CHECK_EQ(outputs[0].getArgType(), ADD_TO);
+    CHECK_EQ(outputs[1].getArgType(), ADD_TO);
+
+    const auto out_grad = inputs[0].matrix<Device>();
+    const auto out_val = inputs[1].matrix<Device>();
+    const auto in1_val = inputs[2].matrix<Device>();
+    const auto in2_val = inputs[3].matrix<Device>();
+    auto in1_grad = outputs[0].matrix<Device>();
+    auto in2_grad = outputs[1].matrix<Device>();
+
+    CosSimBackward<Device>(
+        out_grad, out_val, in1_val, in2_val, in1_grad, in2_grad, scale_);
+  }
+
+private:
+  real scale_;
+};
+
+REGISTER_TYPED_FUNC(CosSimForward, CPU, CosSimForwardFunc);
+REGISTER_TYPED_FUNC(CosSimBackward, CPU, CosSimBackwardFunc);
+#ifndef PADDLE_ONLY_CPU
+REGISTER_TYPED_FUNC(CosSimForward, GPU, CosSimForwardFunc);
+REGISTER_TYPED_FUNC(CosSimBackward, GPU, CosSimBackwardFunc);
+#endif
+}  // namespace paddle
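Callers normally reach these kernels through the Function registry rather than calling the template directly. A minimal sketch of direct use, assuming the `FUNC_NAME`-style lookup string `"CosSimForward-CPU"` produced by `REGISTER_TYPED_FUNC` and a CPU build; this snippet is illustrative, not code from the patch:

```cpp
#include "paddle/function/Function.h"
#include "paddle/math/Matrix.h"

using namespace paddle;  // NOLINT

void cosSimViaRegistry(size_t batchSize, size_t dim) {
  // Look up the CPU function registered in CosSimOp.cpp and configure it.
  FunctionBase* func =
      FunctionBase::funcRegistrar_.createByType("CosSimForward-CPU");
  func->init(FuncConfig().set("scale", (real)1.0));

  MatrixPtr x = Matrix::create(batchSize, dim, /*trans=*/false, /*useGpu=*/false);
  MatrixPtr y = Matrix::create(batchSize, dim, false, false);
  MatrixPtr out = Matrix::create(batchSize, 1, false, false);
  x->randomizeUniform();
  y->randomizeUniform();

  BufferArgs inputs;
  BufferArgs outputs;
  inputs.addArg(*x);
  inputs.addArg(*y);
  outputs.addArg(*out, ASSIGN_TO);  // forward overwrites its output
  func->calc(inputs, outputs);
}
```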
diff --git a/paddle/function/CosSimOp.h b/paddle/function/CosSimOp.h
new file mode 100644
index 0000000000000000000000000000000000000000..be73064e6375bf1e6c6a7ca6de52e9b9b755880b
--- /dev/null
+++ b/paddle/function/CosSimOp.h
@@ -0,0 +1,61 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "Function.h"
+
+namespace paddle {
+
+/**
+ * \brief Cosine Similarity Forward.
+ *        for each row i,
+ *          out[i] = scale * cos(in1[i], in2[i])
+ *                 = scale * \sum_j (in1[i][j] * in2[i][j]) /
+ *                   sqrt(\sum_j (in1[i][j]^2) * \sum_j (in2[i][j]^2))
+ *
+ * \param[out] output  output value.
+ * \param[in]  input1  input value.
+ * \param[in]  input2  input value.
+ * \param[in]  scale   default 1.0.
+ *
+ */
+template <DeviceType Device>
+void CosSimForward(typename Tensor<real, Device>::Matrix& output,
+                   const typename Tensor<real, Device>::Matrix& input1,
+                   const typename Tensor<real, Device>::Matrix& input2,
+                   real scale);
+
+/**
+ * \brief Cosine Similarity Backward for Derivative.
+ *
+ * \param[in]     out_grad   backward loss output grad.
+ * \param[in]     out_value  forward output value.
+ * \param[in]     in1_value  forward input value 1.
+ * \param[in]     in2_value  forward input value 2.
+ * \param[in,out] in1_grad   forward input grad 1.
+ * \param[in,out] in2_grad   forward input grad 2.
+ * \param[in]     scale      default 1.0.
+ *
+ */
+template <DeviceType Device>
+void CosSimBackward(const typename Tensor<real, Device>::Matrix& out_grad,
+                    const typename Tensor<real, Device>::Matrix& out_value,
+                    const typename Tensor<real, Device>::Matrix& in1_value,
+                    const typename Tensor<real, Device>::Matrix& in2_value,
+                    typename Tensor<real, Device>::Matrix& in1_grad,
+                    typename Tensor<real, Device>::Matrix& in2_grad,
+                    real scale);
+
+}  // namespace paddle
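The GPU implementation below assigns one thread block per sample row, has each thread accumulate a strided slice of the row into shared memory, and then does a tree reduction that halves the active thread count each step. A standalone sketch of that pattern, reduced to a single dot product (plain CUDA with `float` instead of Paddle's `real`; names are mine):

```cuda
// One block reduces one row of length `width` to a single dot product.
template <int kBlockSize>
__global__ void DotRow(const float* x, const float* y, float* out, int width) {
  __shared__ float partial[kBlockSize];
  int tid = threadIdx.x;
  float sum = 0.0f;
  // Strided pass: thread tid handles elements tid, tid+kBlockSize, ...
  for (int i = tid; i < width; i += kBlockSize) sum += x[i] * y[i];
  partial[tid] = sum;
  __syncthreads();
  // Tree reduction: halve the number of active threads each iteration.
  for (int s = kBlockSize / 2; s > 0; s >>= 1) {
    if (tid < s) partial[tid] += partial[tid + s];
    __syncthreads();
  }
  if (tid == 0) *out = partial[0];
}
```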
diff --git a/paddle/function/CosSimOpGpu.cu b/paddle/function/CosSimOpGpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..1dd733674fa0542c76070955ec63e008b083c7d2
--- /dev/null
+++ b/paddle/function/CosSimOpGpu.cu
@@ -0,0 +1,241 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "hl_base.h"
+#include "hl_device_functions.cuh"
+#include "CosSimOp.h"
+
+namespace paddle {
+
+template <int block_size>
+__global__ void KeCosSim(real* output,
+                         const real* input1,
+                         const real* input2,
+                         int width,
+                         int input1_height,
+                         int input2_height,
+                         real scale) {
+  const int ty = blockIdx.y;
+  int tid = threadIdx.x;
+
+  __shared__ real xx[block_size];
+  __shared__ real yy[block_size];
+  __shared__ real xy[block_size];
+
+  xx[tid] = 0.0;
+  yy[tid] = 0.0;
+  xy[tid] = 0.0;
+  __syncthreads();
+
+  input1 += ty * width;
+  if (input2_height > 1) {
+    input2 += ty * width;
+  }
+  for (int index = tid; index < width; index += block_size) {
+    real x = input1[index];
+    real y = input2[index];
+    xx[tid] += x * x;
+    yy[tid] += y * y;
+    xy[tid] += x * y;
+  }
+  __syncthreads();
+
+  for (int s = block_size / 2; s > 0; s >>= 1) {
+    if (tid < s) {
+      xx[tid] += xx[tid + s];
+      yy[tid] += yy[tid + s];
+      xy[tid] += xy[tid + s];
+    }
+    __syncthreads();
+  }
+  if (tid == 0) {
+    output[ty] = scale * xy[0] / (sqrt(xx[0]) * sqrt(yy[0]));
+  }
+}
+
+void hlCossim(real* output,
+              const real* input1,
+              const real* input2,
+              size_t width,
+              size_t input1_height,
+              size_t input2_height,
+              real scale) {
+  CHECK_NOTNULL(output);
+  CHECK_NOTNULL(input1);
+  CHECK_NOTNULL(input2);
+  const int block_size = 256;
+  dim3 threads(block_size, 1);
+  dim3 grid(1, input1_height);
+
+  KeCosSim<block_size><<<grid, threads, 0, STREAM_DEFAULT>>>
+      (output, input1, input2, width, input1_height, input2_height, scale);
+  CHECK_SYNC("hlCossim failed");
+}
+
+template <>
+void CosSimForward<DEVICE_TYPE_GPU>(GpuMatrix& out_mat,
+                                    const GpuMatrix& in1_mat,
+                                    const GpuMatrix& in2_mat,
+                                    real scale) {
+  CHECK(out_mat.getData() && in1_mat.getData() && in2_mat.getData());
+  CHECK(in1_mat.useGpu_ == true && in2_mat.useGpu_ == true)
+      << "Matrix type are not GPU";
+
+  size_t num_samples = out_mat.getHeight();
+  size_t dim = in1_mat.getWidth();
+  real* out = out_mat.getData();
+  const real* x = in1_mat.getData();
+  const real* y = in2_mat.getData();
+  hlCossim(out, x, y, dim, in1_mat.getHeight(), in2_mat.getHeight(), scale);
+}
+
+template <int block_size>
+__global__ void KeCosSimDerivative(const real* grad,
+                                   const real* output,
+                                   const real* prev_out_x,
+                                   const real* prev_out_y,
+                                   real* prev_grad_x,
+                                   real* prev_grad_y,
+                                   size_t width,
+                                   size_t input1_height,
+                                   size_t input2_height,
+                                   real scale) {
+  const int ty = blockIdx.y;
+  int tid = threadIdx.x;
+
+  __shared__ real xx[block_size];
+  __shared__ real yy[block_size];
+  __shared__ real xy[block_size];
+
+  xx[tid] = 0.0;
+  yy[tid] = 0.0;
+  xy[tid] = 0.0;
+  __syncthreads();
+
+  prev_out_x += ty * width;
+  prev_grad_x += ty * width;
+  if (input2_height > 1) {
+    prev_out_y += ty * width;
+    prev_grad_y += ty * width;
+  }
+  for (int index = tid; index < width; index += block_size) {
+    real x = prev_out_x[index];
+    real y = prev_out_y[index];
+    xx[tid] += x * x;
+    yy[tid] += y * y;
+    xy[tid] += x * y;
+  }
+  __syncthreads();
+
+  for (int s = block_size / 2; s > 0; s >>= 1) {
+    if (tid < s) {
+      xx[tid] += xx[tid + s];
+      yy[tid] += yy[tid + s];
+      xy[tid] += xy[tid + s];
+    }
+    __syncthreads();
+  }
+  if (xy[0] == 0) {
+    real reciprocal = 1.0 / (sqrt(xx[0]) * sqrt(yy[0]));
+    for (int index = tid; index < width; index += block_size) {
+      prev_grad_x[index] +=
+          scale * grad[ty] * prev_out_y[index] * reciprocal;
+      if (input2_height > 1) {
+        prev_grad_y[index] +=
+            scale * grad[ty] * prev_out_x[index] * reciprocal;
+      } else {
+        paddle::paddleAtomicAdd(prev_grad_y + index,
+            scale * grad[ty] * prev_out_x[index] * reciprocal);
+      }
+    }
+  } else {
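One detail worth noting in the derivative kernel: when `input2` has a single row (`input2_height == 1`), every block writes its contribution into the same shared gradient row, so the accumulation into `prev_grad_y` must go through `paddle::paddleAtomicAdd` instead of a plain `+=`; otherwise concurrent blocks would race on the same elements. The per-sample (`input2_height > 1`) path writes to disjoint rows and needs no atomics. (This is a reading of the branches above, not extra behavior added by the patch.)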
+    real reciprocalXY = 1.0 / xy[0];
+    real reciprocalSquareSumX = 1.0 / xx[0];
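+    // (continuing the non-degenerate branch: the same factored gradient as
+    // the CPU implementation, evaluated per element with strided threads)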
+    real reciprocalSquareSumY = 1.0 / yy[0];
+    for (int index = tid; index < width; index += block_size) {
+      prev_grad_x[index] += output[ty] * grad[ty] *
+          (prev_out_y[index] * reciprocalXY -
+           prev_out_x[index] * reciprocalSquareSumX);
+      if (input2_height > 1) {
+        prev_grad_y[index] += output[ty] * grad[ty] *
+            (prev_out_x[index] * reciprocalXY -
+             prev_out_y[index] * reciprocalSquareSumY);
+      } else {
+        paddle::paddleAtomicAdd(prev_grad_y + index, output[ty] * grad[ty] *
+            (prev_out_x[index] * reciprocalXY -
+             prev_out_y[index] * reciprocalSquareSumY));
+      }
+    }
+  }
+}
+
+void hlCossimDerivative(const real* grad,
+                        const real* output,
+                        const real* prev_out_x,
+                        const real* prev_out_y,
+                        real* prev_grad_x,
+                        real* prev_grad_y,
+                        size_t width,
+                        size_t input1_height,
+                        size_t input2_height,
+                        real scale) {
+  CHECK_NOTNULL(grad);
+  CHECK_NOTNULL(output);
+  CHECK_NOTNULL(prev_out_x);
+  CHECK_NOTNULL(prev_out_y);
+  CHECK_NOTNULL(prev_grad_x);
+  CHECK_NOTNULL(prev_grad_y);
+  const int block_size = 256;
+  dim3 threads(block_size, 1);
+  dim3 grid(1, input1_height);
+  KeCosSimDerivative<block_size><<<grid, threads, 0, STREAM_DEFAULT>>>
+      (grad, output, prev_out_x, prev_out_y, prev_grad_x, prev_grad_y, width,
+       input1_height, input2_height, scale);
+  CHECK_SYNC("hlCossimDerivate failed");
+}
+
+template <>
+void CosSimBackward<DEVICE_TYPE_GPU>(const GpuMatrix& out_grad,
+                                     const GpuMatrix& out_val,
+                                     const GpuMatrix& in1_val,
+                                     const GpuMatrix& in2_val,
+                                     GpuMatrix& in1_grad,
+                                     GpuMatrix& in2_grad,
+                                     real scale) {
+  CHECK(out_grad.getData() && out_val.getData() && in1_val.getData() &&
+        in2_val.getData() && in1_grad.getData() && in2_grad.getData());
+  CHECK(out_grad.useGpu_ && out_val.useGpu_ && in1_val.useGpu_ &&
+        in2_val.useGpu_ && in1_grad.useGpu_ && in2_grad.useGpu_)
+      << "Matrix types are not equally GPU";
+
+  size_t dim = in1_val.getWidth();
+  const real* grad = out_grad.getData();
+  const real* out = out_val.getData();
+  const real* prev_out_x = in1_val.getData();
+  const real* prev_out_y = in2_val.getData();
+  real* prev_grad_x = in1_grad.getData();
+  real* prev_grad_y = in2_grad.getData();
+  hlCossimDerivative(grad,
+                     out,
+                     prev_out_x,
+                     prev_out_y,
+                     prev_grad_x,
+                     prev_grad_y,
+                     dim,
+                     in1_val.getHeight(),
+                     in2_val.getHeight(),
+                     scale);
+}
+
+}  // namespace paddle
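The test file that follows exercises both specializations through `FunctionCompare` (from `FunctionTest.h`), which instantiates the named function for CPU and GPU, fills the declared `BufferArg`s with random data, runs both, and checks that the results agree. Adding a case is just another call into the helpers it defines; for example, a hypothetical extra test isolating the broadcast path (`height_y == 1`) on small shapes:

```cpp
// Hypothetical addition, reusing testCosSimForward/testCosSimBackward below.
TEST(Matrix, cosSimBroadcastRow) {
  // One shared in2 row compared against ten in1 rows.
  testCosSimForward(/*height_x=*/10, /*height_y=*/1, /*width=*/10, /*scale=*/1.0);
  testCosSimBackward(/*height_x=*/10, /*height_y=*/1, /*width=*/10, /*scale=*/1.0);
}
```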
diff --git a/paddle/function/CosSimOpTest.cpp b/paddle/function/CosSimOpTest.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..48c815f027161b48c17ce654ab819156fd856199
--- /dev/null
+++ b/paddle/function/CosSimOpTest.cpp
@@ -0,0 +1,64 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include "FunctionTest.h"
+#include "paddle/math/Matrix.h"
+
+using namespace paddle;  // NOLINT
+
+void testCosSimForward(size_t height_x,
+                       size_t height_y,
+                       size_t width,
+                       real scale) {
+  FunctionCompare test("CosSimForward", FuncConfig().set("scale", scale));
+  // prepare input arguments
+  test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width}));
+  test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width}));
+  test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1}),
+                  ASSIGN_TO);
+  // run Function
+  test.run();
+}
+
+void testCosSimBackward(size_t height_x,
+                        size_t height_y,
+                        size_t width,
+                        real scale) {
+  FunctionCompare test("CosSimBackward", FuncConfig().set("scale", scale));
+  // prepare input arguments
+  test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1}));
+  test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1}));
+  test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width}));
+  test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width}));
+  test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width}),
+                  ADD_TO);
+  test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width}),
+                  ADD_TO);
+  // run Function
+  test.run();
+}
+
+TEST(Matrix, cosSim) {
+  for (auto height_x : {10, 100, 1000}) {
+    for (auto height_y : {1, height_x}) {
+      for (auto width : {10, 100, 1000}) {
+        for (auto scale : {1.0, 2.0}) {
+          testCosSimForward(height_x, height_y, width, scale);
+          testCosSimBackward(height_x, height_y, width, scale);
+        }
+      }
+    }
+  }
+}
diff --git a/paddle/gserver/layers/CosSimLayer.cpp b/paddle/gserver/layers/CosSimLayer.cpp
index 254120443dc3d41bf2422be2e88cb376d70c93d4..a6c0300acf6752a3536e7939577b561fd97d1eb8 100644
--- a/paddle/gserver/layers/CosSimLayer.cpp
+++ b/paddle/gserver/layers/CosSimLayer.cpp
@@ -26,15 +26,23 @@ bool CosSimLayer::init(const LayerMap& layerMap,
   Layer::init(layerMap, parameterMap);
 
   CHECK_EQ(inputLayers_.size(), 2LU);
+
+  createFunction(forward_,
+                 "CosSimForward",
+                 FuncConfig().set("scale", (real)config_.cos_scale()));
+  createFunction(backward_,
+                 "CosSimBackward",
+                 FuncConfig().set("scale", (real)config_.cos_scale()));
+
   return true;
 }
 
 void CosSimLayer::forward(PassType passType) {
   Layer::forward(passType);
-
   /* malloc memory for the output_ if necessary */
   int batchSize = getInputValue(0)->getHeight();
   int size = getSize();
+  CHECK_EQ(forward_.size(), 1) << "Only one forward function needed";
 
   {
     REGISTER_TIMER_INFO("CosFwResetTimer", getName().c_str());
@@ -42,26 +50,43 @@ void CosSimLayer::forward(PassType passType) {
   }
 
   MatrixPtr outV = getOutputValue();
-
   /* activation */ {
     REGISTER_TIMER_INFO("CosFwAtvTimer", getName().c_str());
     MatrixPtr prevOut1 = getInputValue(0);
     MatrixPtr prevOut2 = getInputValue(1);
-    outV->cosSim(*prevOut1, *prevOut2, config_.cos_scale());
+
+    CHECK(outV && prevOut1 && prevOut2);
+    BufferArgs inputs;
+    BufferArgs outputs;
+    inputs.addArg(*prevOut1);
+    inputs.addArg(*prevOut2);
+    outputs.addArg(*outV, ASSIGN_TO);
+    forward_[0]->calc(inputs, outputs);
   }
 }
 
 void CosSimLayer::backward(const UpdateCallback& callback) {
   /* activation */ {
     REGISTER_TIMER_INFO("CosBpAtvTimer", getName().c_str());
-    MatrixPtr outG = this->getOutputGrad();
-
-    outG->cosSimDerivative(*this->getOutputValue(),
-                           *getInputValue(0),
-                           *getInputValue(1),
-                           *getInputGrad(0),
-                           *getInputGrad(1),
-                           config_.cos_scale());
+    CHECK_EQ(backward_.size(), 1) << "Only one backward function needed";
+
+    const auto outG = this->getOutputGrad();
+    const auto outV = this->getOutputValue();
+    const auto inV1 = this->getInputValue(0);
+    const auto inV2 = this->getInputValue(1);
+    auto inG1 = this->getInputGrad(0);
+    auto inG2 = this->getInputGrad(1);
+    CHECK(outG && outV && inV1 && inV2 && inG1 && inG2);
+    BufferArgs inputs;
+    BufferArgs outputs;
+    inputs.addArg(*outG);
+    inputs.addArg(*outV);
+    inputs.addArg(*inV1);
+    inputs.addArg(*inV2);
+    outputs.addArg(*inG1, ADD_TO);
+    outputs.addArg(*inG2, ADD_TO);
+
+    backward_[0]->calc(inputs, outputs);
   }
 }
diff --git a/paddle/gserver/layers/CosSimLayer.h b/paddle/gserver/layers/CosSimLayer.h
index 65549626098f084c5e1786885e06c1bdfa3ba74c..8afaee62c2dcacba006846df0111fcbe8f7575e4 100644
--- a/paddle/gserver/layers/CosSimLayer.h
+++ b/paddle/gserver/layers/CosSimLayer.h
@@ -28,7 +28,7 @@ namespace paddle {
  *
  * - Input1: A vector (batchSize * dataDim)
  * - Input2: A vector (batchSize * dataDim) or (1 * dataDim)
- * - Output: A vector (dataDim * 1)
+ * - Output: A vector (batchSize * 1)
  *
  * The config file api is cos_sim.
  */
diff --git a/paddle/gserver/layers/CosSimVecMatLayer.cpp b/paddle/gserver/layers/CosSimVecMatLayer.cpp
index 5f652319e5620227fca166a8f72e5aed416bf5dd..aabafd473aa1e06a767d48d4c49b7b8662e992e7 100644
--- a/paddle/gserver/layers/CosSimVecMatLayer.cpp
+++ b/paddle/gserver/layers/CosSimVecMatLayer.cpp
@@ -18,7 +18,6 @@ limitations under the License. */
 #include "paddle/utils/Stat.h"
 
 namespace paddle {
-
 /**
  * @brief A layer for computing cosine similarity between a vector
  *        and each row of a matrix
@@ -98,11 +97,22 @@ bool CosSimVecMatLayer::init(const LayerMap& layerMap,
                             dataDim,
                             /* trans= */ false,
                             useGpu_);
+
+  CHECK(tmpRow0 && tmpRow1 && tmpRow2 && tmpRow3 && tmpMtx0 && tmpMtx1);
+
+  createFunction(forward_,
+                 "CosSimForward",
+                 FuncConfig().set("scale", (real)config_.cos_scale()));
+  createFunction(backward_,
+                 "CosSimBackward",
+                 FuncConfig().set("scale", (real)config_.cos_scale()));
+
   return true;
 }
 
 void CosSimVecMatLayer::forward(PassType passType) {
   Layer::forward(passType);
+  CHECK_EQ(forward_.size(), 1) << "Only one forward function needed";
 
   MatrixPtr inV0 = getInputValue(0);
   MatrixPtr inV1 = getInputValue(1);
@@ -118,17 +128,25 @@ void CosSimVecMatLayer::forward(PassType passType) {
   }
 
   MatrixPtr outV = getOutputValue();
-
+  CHECK(outV && inV0 && inV1);
   REGISTER_TIMER_INFO("FwCosVMTimer", getName().c_str());
   for (size_t i = 0; i < batchSize; i++) {
     tmpRow0->setData(inV0->rowBuf(i));
     tmpMtx0->setData(inV1->rowBuf(i));
     tmpRow2->setData(outV->rowBuf(i));
-    tmpRow2->cosSim(*(tmpMtx0), *(tmpRow0), config_.cos_scale());
+
+    BufferArgs inputs;
+    BufferArgs outputs;
+    inputs.addArg(*tmpMtx0);
+    inputs.addArg(*tmpRow0);
+    outputs.addArg(*tmpRow2, ASSIGN_TO);
+    forward_[0]->calc(inputs, outputs);
   }
 }
 
 void CosSimVecMatLayer::backward(const UpdateCallback& callback) {
+  CHECK_EQ(backward_.size(), 1) << "Only one backward function needed";
+
   MatrixPtr inV0 = getInputValue(0);
   MatrixPtr inV1 = getInputValue(1);
   MatrixPtr inG0 = getInputGrad(0);
@@ -137,27 +155,27 @@ void CosSimVecMatLayer::backward(const UpdateCallback& callback) {
   MatrixPtr outG = getOutputGrad();
 
   size_t batchSize = inV0->getHeight();
-
+  CHECK(inV0 && inV1 && inG0 && inG1 && outV && outG);
   REGISTER_TIMER_INFO("BwCosVMTimer", getName().c_str());
-  if (inG0 && inG1) {
-    for (size_t i = 0; i < batchSize; i++) {
-      tmpRow0->setData(inV0->rowBuf(i));
-      tmpRow1->setData(inG0->rowBuf(i));
-      tmpMtx0->setData(inV1->rowBuf(i));
-      tmpMtx1->setData(inG1->rowBuf(i));
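A note on the `ArgType` values used in the layer wiring above: `ASSIGN_TO` is the contract that lets the forward function overwrite `outV`, while `ADD_TO` makes the backward function accumulate into gradients that may already hold contributions from other paths. A condensed sketch of the backward wiring, with the argument order `CosSimBackwardFunc::calc` expects (the helper name is mine, not from the patch):

```cpp
#include "paddle/function/Function.h"
#include "paddle/math/Matrix.h"

using namespace paddle;  // NOLINT

// inputs  = { out_grad, out_val, in1_val, in2_val }
// outputs = { in1_grad, in2_grad }, both accumulated via ADD_TO
void runCosSimBackward(FunctionBase* backward,
                       Matrix& outG, Matrix& outV,
                       Matrix& inV1, Matrix& inV2,
                       Matrix& inG1, Matrix& inG2) {
  BufferArgs inputs;
  BufferArgs outputs;
  inputs.addArg(outG);
  inputs.addArg(outV);
  inputs.addArg(inV1);
  inputs.addArg(inV2);
  outputs.addArg(inG1, ADD_TO);  // grad += ..., never overwritten
  outputs.addArg(inG2, ADD_TO);
  backward->calc(inputs, outputs);
}
```

The `CosSimVecMatLayer` diff that follows uses the same functions per sample row: it repoints small `(1 x dim)` wrapper matrices at successive rows with `setData()`, so each `calc()` call sees a zero-copy row view, as shown in the loop below.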
-      tmpRow2->setData(outV->rowBuf(i));
-      tmpRow3->setData(outG->rowBuf(i));
-
-      tmpRow3->cosSimDerivative(*(tmpRow2),
-                                *(tmpMtx0),
-                                *(tmpRow0),
-                                *(tmpMtx1),
-                                *(tmpRow1),
-                                config_.cos_scale());
-    }
-  } else {
-    CHECK(!inG0 || !inG1) << "Not supported";
+  for (size_t i = 0; i < batchSize; i++) {
+    tmpRow0->setData(inV0->rowBuf(i));
+    tmpRow1->setData(inG0->rowBuf(i));
+    tmpMtx0->setData(inV1->rowBuf(i));
+    tmpMtx1->setData(inG1->rowBuf(i));
+    tmpRow2->setData(outV->rowBuf(i));
+    tmpRow3->setData(outG->rowBuf(i));
+
+    BufferArgs inputs;
+    BufferArgs outputs;
+    inputs.addArg(*tmpRow3);
+    inputs.addArg(*tmpRow2);
+    inputs.addArg(*tmpMtx0);
+    inputs.addArg(*tmpRow0);
+    outputs.addArg(*tmpMtx1, ADD_TO);
+    outputs.addArg(*tmpRow1, ADD_TO);
+
+    backward_[0]->calc(inputs, outputs);
   }
 }
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index a8b53e2105b053399e62fba5321fd22c1fe4a50d..1964b2f8bfaebc49fe3073e03c949a8a9c3e385a 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -941,59 +941,6 @@ void GpuMatrix::softreluDerivative(Matrix& output) {
 void GpuMatrix::scaledTanh(Matrix& output, real p1, real p2) {
   BaseMatrix::scaledTanh(output, p1, p2);
 }
-void GpuMatrix::cosSim(Matrix& output1, Matrix& output2, real scale) {
-  CHECK(output1.useGpu_ == true && output2.useGpu_ == true)
-      << "Matrix type are not equal";
-  size_t numSamples = getHeight();
-  size_t dim = output1.getWidth();
-  CHECK_EQ(getWidth(), 1UL);
-  CHECK_EQ(output1.getHeight(), numSamples);
-  CHECK_EQ(output1.getWidth(), output2.getWidth());
-  real* out = getData();
-  real* x = output1.getData();
-  real* y = output2.getData();
-  hl_cossim(out, x, y, dim, output1.getHeight(), output2.getHeight(), scale);
-}
-void GpuMatrix::cosSimDerivative(Matrix& output,
-                                 Matrix& prevOut1,
-                                 Matrix& prevOut2,
-                                 Matrix& prevGrad1,
-                                 Matrix& prevGrad2,
-                                 real scale) {
-  CHECK(output.useGpu_ == true && prevOut1.useGpu_ == true &&
-        prevOut2.useGpu_ == true && prevGrad1.useGpu_ == true &&
-        prevGrad2.useGpu_ == true)
-      << "Matrix type are not equal";
-  CHECK_EQ(getWidth(), 1UL);
-  CHECK_EQ(output.getWidth(), 1UL);
-
-  size_t numSamples = getHeight();
-  CHECK_EQ(output.getHeight(), numSamples);
-  CHECK_EQ(prevOut1.getHeight(), numSamples);
-  CHECK_EQ(prevGrad1.getHeight(), numSamples);
-
-  size_t dim = prevOut1.getWidth();
-  CHECK_EQ(prevOut2.getWidth(), dim);
-  CHECK_EQ(prevGrad1.getWidth(), dim);
-  CHECK_EQ(prevGrad2.getWidth(), dim);
-
-  real* grad = getData();
-  real* out = output.getData();
-  real* prevOutX = prevOut1.getData();
-  real* prevOutY = prevOut2.getData();
-  real* prevGradX = prevGrad1.getData();
-  real* prevGradY = prevGrad2.getData();
-  hl_cossim_derivative(grad,
-                       out,
-                       prevOutX,
-                       prevOutY,
-                       prevGradX,
-                       prevGradY,
-                       dim,
-                       prevOut1.getHeight(),
-                       prevOut2.getHeight(),
-                       scale);
-}
 
 void GpuMatrix::randomizeUniform() {
   CHECK(isContiguous());
@@ -3470,105 +3417,6 @@ void CpuMatrix::softmaxDerivative(Matrix& output, Matrix& sftmaxSum) {
   }
 }
 
-void CpuMatrix::cosSim(Matrix& output1, Matrix& output2, real scale) {
-  size_t numSamples = getHeight();
-  size_t dim = output1.getWidth();
-  CHECK_EQ(getWidth(), 1UL);
-  CHECK_EQ(output1.getHeight(), numSamples);
-  CHECK_EQ(output1.getWidth(), output2.getWidth());
-
-  real* out = getData();
-  const real* x = output1.getData();
-  const real* y = output2.getData();
-  size_t yInc = dim;
-  if (output2.getHeight() == 1LU) {
-    yInc = 0;
-  } else {
-    CHECK_EQ(output2.getHeight(), numSamples);
-  }
-  for (size_t i = 0; i < numSamples; ++i, x += dim, y += yInc) {
-    real squareSumX = 0;
-    real squareSumY = 0;
-    real xy = 0;
-    for (size_t j = 0; j < dim; ++j) {
-      squareSumX += _square(x[j]);
-      squareSumY += _square(y[j]);
-      xy += x[j] * y[j];
-    }
-    CHECK(squareSumX > 0 && squareSumY > 0);
-    out[i] = scale * xy / (std::sqrt(squareSumX) * std::sqrt(squareSumY));
-  }
-}
-
-void CpuMatrix::cosSimDerivative(Matrix& output,
-                                 Matrix& prevOut1,
-                                 Matrix& prevOut2,
-                                 Matrix& prevGrad1,
-                                 Matrix& prevGrad2,
-                                 real scale) {
-  CHECK(output.useGpu_ == false) << "Matrix type are not equal";
-
-  CHECK_EQ(getWidth(), 1UL);
-  CHECK_EQ(output.getWidth(), 1UL);
-
-  size_t numSamples = getHeight();
-  CHECK_EQ(output.getHeight(), numSamples);
-  CHECK_EQ(prevOut1.getHeight(), numSamples);
-  CHECK_EQ(prevGrad1.getHeight(), numSamples);
-
-  size_t dim = prevOut1.getWidth();
-  CHECK_EQ(prevOut2.getWidth(), dim);
-  CHECK_EQ(prevGrad1.getWidth(), dim);
-  CHECK_EQ(prevGrad2.getWidth(), dim);
-
-  const real* grad = getData();
-  const real* out = output.getData();
-  const real* prevOutX = prevOut1.getData();
-  const real* prevOutY = prevOut2.getData();
-  real* prevGradX = prevGrad1.getData();
-  real* prevGradY = prevGrad2.getData();
-  size_t yInc = dim;
-  if (prevOut2.getHeight() == 1LU) {
-    yInc = 0;
-    CHECK_EQ(prevGrad2.getHeight(), 1LU);
-  } else {
-    CHECK_EQ(prevOut2.getHeight(), numSamples);
-    CHECK_EQ(prevGrad2.getHeight(), numSamples);
-  }
-  for (size_t i = 0; i < numSamples; ++i,
-              prevOutX += dim,
-              prevOutY += yInc,
-              prevGradX += dim,
-              prevGradY += yInc) {
-    real squareSumX = 0;
-    real squareSumY = 0;
-    real xy = 0;
-    for (size_t j = 0; j < dim; ++j) {
-      squareSumX += _square(prevOutX[j]);
-      squareSumY += _square(prevOutY[j]);
-      xy += prevOutX[j] * prevOutY[j];
-    }
-    CHECK(squareSumX > 0 && squareSumY > 0);
-    if (xy == 0) {
-      real reciprocal = 1.0f / (std::sqrt(squareSumX) * std::sqrt(squareSumY));
-      for (size_t j = 0; j < dim; ++j) {
-        prevGradX[j] += scale * grad[i] * prevOutY[j] * reciprocal;
-        prevGradY[j] += scale * grad[i] * prevOutX[j] * reciprocal;
-      }
-    } else {
-      real reciprocalXY = 1.0f / xy;
-      real reciprocalSquareSumX = 1.0f / squareSumX;
-      real reciprocalSquareSumY = 1.0f / squareSumY;
-      for (size_t j = 0; j < dim; ++j) {
-        prevGradX[j] += out[i] * grad[i] * (prevOutY[j] * reciprocalXY -
-                                            prevOutX[j] * reciprocalSquareSumX);
-        prevGradY[j] += out[i] * grad[i] * (prevOutX[j] * reciprocalXY -
-                                            prevOutY[j] * reciprocalSquareSumY);
-      }
-    }
-  }
-}
-
 void CpuMatrix::sumOfSquares(Matrix& output, Matrix& label) {
   CHECK(output.useGpu_ == false && label.useGpu_ == false)
       << "Matrix type are not equal";
diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h
index c92c0a272d5a72868bd61035d77aa4ed0fad7a7c..ea4bbb86b057b526c5ea294b2cd835aef65de58d 100644
--- a/paddle/math/Matrix.h
+++ b/paddle/math/Matrix.h
@@ -799,26 +799,6 @@ public:
     LOG(FATAL) << "Not implemented";
   }
 
-  /**
-   * cosine similarity, for each row i,
-   *   this[i] = cos(output1[i], output2[i])
-   *
-   * output2 can only have one row, then for each row i,
-   *   this[i] = cos(output1[i], output2[0])
-   */
-  virtual void cosSim(Matrix& output1, Matrix& output2, real scale = 1.0f) {
-    LOG(FATAL) << "Not implemented";
-  }
-
-  virtual void cosSimDerivative(Matrix& output,
-                                Matrix& prevOut1,
-                                Matrix& prevOut2,
-                                Matrix& prevGrad1,
-                                Matrix& prevGrad2,
-                                real scale = 1.0f) {
-    LOG(FATAL) << "Not implemented";
-  }
-
   /// print out the values of elements to os
   virtual void print(std::ostream& os) const {
     LOG(FATAL) << "Not implemented";
   }
@@ -1324,14 +1304,6 @@ public:
   void softreluDerivative(Matrix& output);
   void scaledTanh(Matrix& output, real p1, real p2);
 
-  void cosSim(Matrix& output1, Matrix& output2, real scale);
-  void cosSimDerivative(Matrix& output,
-                        Matrix& prevOut1,
-                        Matrix& prevOut2,
-                        Matrix& prevGrad1,
-                        Matrix& prevGrad2,
-                        real scale);
-
   virtual void print(std::ostream& os) const;
   virtual void print(std::ostream& os, size_t height, size_t width) const;
 
@@ -1752,14 +1724,6 @@ public:
   void softreluDerivative(Matrix& output);
   void scaledTanh(Matrix& output, real p1, real p2);
 
-  void cosSim(Matrix& output1, Matrix& output2, real scale);
-  void cosSimDerivative(Matrix& output,
-                        Matrix& prevOut1,
-                        Matrix& prevOut2,
-                        Matrix& prevGrad1,
-                        Matrix& prevGrad2,
-                        real scale);
-
   void print(std::ostream& os) const;
   void print(std::ostream& os, size_t height, size_t width) const;
   void printOneRow(std::ostream& os, size_t idx) const;
diff --git a/paddle/math/tests/test_Matrix.cpp b/paddle/math/tests/test_Matrix.cpp
index a4084bdf7c6953651bfd9714fd8a5c930f774fe6..1c21da5b76e95603258a5006d0c57b00126e65b9 100644
--- a/paddle/math/tests/test_Matrix.cpp
+++ b/paddle/math/tests/test_Matrix.cpp
@@ -181,28 +181,6 @@ TEST(Matrix, copyByRowIndex) {
   }
 }
 
-void testCosSim(int heightX, int heightY, int width, real scale) {
-  AutoCompare test(heightX, 1);
-  CpuMatrix arg1(heightX, width);
-  CpuMatrix arg2(heightY, width);
-  arg1.randomizeUniform();
-  arg2.randomizeUniform();
-  arg2.add(-0.5);
-  test.cmpWithArg(&Matrix::cosSim, arg1, arg2, scale);
-}
-
-TEST(Matrix, cosSim) {
-  for (auto heightX : {10, 100, 1000}) {
-    for (auto heightY : {1, heightX}) {
-      for (auto width : {10, 100, 1000}) {
-        for (auto scale : {1.0, 2.0}) {
-          testCosSim(heightX, heightY, width, scale);
-        }
-      }
-    }
-  }
-}
-
 void testParamReluForward(int height, int width, int w_height, int w_width) {
   AutoCompare test(height, width);
   CpuMatrix arg1(height, width);
diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp
index e024f2cf1b913f56301ac7b3380f0c382818f413..6caaea443c1df756bfeb775154e8a90400cc3211 100644
--- a/paddle/math/tests/test_matrixCompare.cpp
+++ b/paddle/math/tests/test_matrixCompare.cpp
@@ -720,61 +720,6 @@ TEST(Matrix, sequenceAvgForward) {
   }
 }
 
-void testCosSimDerivate(int heightX, int heightY, int width, real scale) {
-  MatrixPtr prevOutX = CpuMatrix::create(heightX, width, false, false);
-  MatrixPtr prevOutY = CpuMatrix::create(heightY, width, false, false);
-  MatrixPtr grad = CpuMatrix::create(heightX, 1, false, false);
-  MatrixPtr output = CpuMatrix::create(heightX, 1, false, false);
-  MatrixPtr prevGradX = CpuMatrix::create(heightX, width, false, false);
-  MatrixPtr prevGradY = CpuMatrix::create(heightY, width, false, false);
-
-  prevOutX->randomizeUniform();
-  prevOutY->randomizeUniform();
-  grad->randomizeUniform();
-  output->randomizeUniform();
-  prevGradX->randomizeUniform();
-  prevGradY->randomizeUniform();
-
-  MatrixPtr prevOutXGpu = GpuMatrix::create(heightX, width, false, true);
-  MatrixPtr prevOutYGpu = GpuMatrix::create(heightY, width, false, true);
-  MatrixPtr gradGpu = GpuMatrix::create(heightX, 1, false, true);
-  MatrixPtr outputGpu = GpuMatrix::create(heightX, 1, false, true);
-  MatrixPtr prevGradXGpu = GpuMatrix::create(heightX, width, false, true);
-  MatrixPtr prevGradYGpu = GpuMatrix::create(heightY, width, false, true);
-
-  prevOutXGpu->copyFrom(*prevOutX);
-  prevOutYGpu->copyFrom(*prevOutY);
-  gradGpu->copyFrom(*grad);
-  outputGpu->copyFrom(*output);
-  prevGradXGpu->copyFrom(*prevGradX);
-  prevGradYGpu->copyFrom(*prevGradY);
-
-  grad->cosSimDerivative(
-      *output, *prevOutX, *prevOutY, *prevGradX, *prevGradY, scale);
-
-  gradGpu->cosSimDerivative(*outputGpu,
-                            *prevOutXGpu,
-                            *prevOutYGpu,
-                            *prevGradXGpu,
-                            *prevGradYGpu,
-                            scale);
-
-  TensorCheckErr(*prevGradX, *prevGradXGpu);
-  TensorCheckErr(*prevGradY, *prevGradYGpu);
-}
-
-TEST(Matrix, cosSimDerivate) {
-  for (auto heightX : {1, 10, 100}) {
-    for (auto heightY : {1, heightX}) {
-      for (auto width : {1, 10, 100}) {
-        for (auto scale : {1.0, 2.0}) {
-          testCosSimDerivate(heightX, heightY, width, scale);
-        }
-      }
-    }
-  }
-}
-
 void testParamReluBackwardDiff(int height, int width, int w_height,