提交 8604666e 编写于 作者: T tianbingsz 提交者: GitHub

Merge pull request #1061 from tianbingsz/paddle_func

Cosine Similarity Paddle Function.
......@@ -188,48 +188,6 @@ extern void hl_param_relu_backward_diff(real* grad_o,
int width,
int height,
int partial_sum);
/**
* @brief cos sim forward
*
* @param[out] output output data
* @param[in] input1 input1 data(matrix)
* @param[in] input2 input2 data(matrix or vector)
* @param[in] width matrix width
* @param[in] input1_height input1_height
* @param[in] input2_height input2_height
* @param[in] scale scale factor
*/
extern void hl_cossim(real* output,
real* input1,
real* input2,
int width,
int input1_height,
int input2_height,
real scale);
/**
* @brief cos sim derivate
*
* @param[in] grad output grad
* @param[in] output output data
* @param[in] prevOutX input1 data
* @param[in] prevOutY input2 data
* @param[out] prevGradX input1 grad
* @param[out] prevGradY input2 grad
* @param[in] width matrix width
* @param[in] input1_height input1 height
* @param[in] input2_height input2 height
* @param[in] scale scale factor
*/
extern void hl_cossim_derivative(real* grad,
real* output,
real* prevOutX,
real* prevOutY,
real* prevGradX,
real* prevGradY,
int width,
int input1_height,
int input2_height,
real scale);
/**
* @brief Matrix addition: A_d[i][j] += scale * B_d[j/channel].
......
......@@ -74,25 +74,6 @@ inline void hl_param_relu_backward_diff(real* grad_o,
int height,
int partial_sum) {}
inline void hl_cossim(real* output,
real* input1,
real* input2,
int width,
int input1_height,
int input2_height,
real scale) {}
inline void hl_cossim_derivative(real* grad,
real* output,
real* prevOutX,
real* prevOutY,
real* prevGradX,
real* prevGradY,
int width,
int input1_height,
int input2_height,
real scale) {}
inline void hl_matrix_add_shared_bias(real* A_d,
real* B_d,
const int channel,
......
......@@ -584,177 +584,6 @@ void hl_param_relu_backward_diff(real* grad_o,
CHECK_SYNC("hl_param_relu_backward_diff failed");
}
template<int blockSize>
__global__ void KeCosSim(real* output,
real* input1,
real* input2,
int width,
int input1_height,
int input2_height,
real scale) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
__shared__ real xx[blockSize];
__shared__ real yy[blockSize];
__shared__ real xy[blockSize];
xx[tid] = 0.0;
yy[tid] = 0.0;
xy[tid] = 0.0;
__syncthreads();
input1 += ty * width;
if (input2_height > 1) {
input2 += ty * width;
}
for (int index = tid; index < width; index += blockSize) {
real x = input1[index];
real y = input2[index];
xx[tid] += x * x;
yy[tid] += y * y;
xy[tid] += x * y;
}
__syncthreads();
for (int s = blockSize / 2; s > 0; s >>= 1) {
if (tid < s) {
xx[tid] += xx[tid + s];
yy[tid] += yy[tid + s];
xy[tid] += xy[tid + s];
}
__syncthreads();
}
if (tid == 0) {
output[ty] = scale * xy[0] / (sqrt(xx[0]) * sqrt(yy[0]));
}
}
void hl_cossim(real* output,
real* input1,
real* input2,
int width,
int input1_height,
int input2_height,
real scale) {
CHECK_NOTNULL(output);
CHECK_NOTNULL(input1);
CHECK_NOTNULL(input2);
const int blockSize = 256;
dim3 threads(blockSize, 1);
dim3 grid(1, input1_height);
KeCosSim<blockSize><<<grid, threads, 0, STREAM_DEFAULT>>>
(output, input1, input2, width, input1_height, input2_height, scale);
CHECK_SYNC("hl_cossim failed");
}
template<int blockSize>
__global__ void KeCosSimDerivative(real* grad,
real* output,
real* prevOutX,
real* prevOutY,
real* prevGradX,
real* prevGradY,
int width,
int input1_height,
int input2_height,
real scale) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
__shared__ real xx[blockSize];
__shared__ real yy[blockSize];
__shared__ real xy[blockSize];
xx[tid] = 0.0;
yy[tid] = 0.0;
xy[tid] = 0.0;
__syncthreads();
prevOutX += ty * width;
prevGradX += ty * width;
if (input2_height > 1) {
prevOutY += ty * width;
prevGradY += ty * width;
}
for (int index = tid; index < width; index += blockSize) {
real x = prevOutX[index];
real y = prevOutY[index];
xx[tid] += x * x;
yy[tid] += y * y;
xy[tid] += x * y;
}
__syncthreads();
for (int s = blockSize / 2; s > 0; s >>= 1) {
if (tid < s) {
xx[tid] += xx[tid + s];
yy[tid] += yy[tid + s];
xy[tid] += xy[tid + s];
}
__syncthreads();
}
if (xy[0] == 0) {
real reciprocal = 1.0 / (sqrt(xx[0]) * sqrt(yy[0]));
for (int index = tid; index < width; index += blockSize) {
prevGradX[index] +=
scale * grad[ty] * prevOutY[index] * reciprocal;
if (input2_height > 1) {
prevGradY[index] +=
scale * grad[ty] * prevOutX[index] * reciprocal;
} else {
paddle::paddleAtomicAdd(prevGradY + index,
scale * grad[ty] * prevOutX[index] * reciprocal);
}
}
} else {
real reciprocalXY = 1.0 / xy[0];
real reciprocalSquareSumX = 1.0 / xx[0];
real reciprocalSquareSumY = 1.0 / yy[0];
for (int index = tid; index < width; index += blockSize) {
prevGradX[index] += output[ty] * grad[ty] *
(prevOutY[index] * reciprocalXY -
prevOutX[index] * reciprocalSquareSumX);
if (input2_height > 1) {
prevGradY[index] += output[ty] * grad[ty] *
(prevOutX[index] * reciprocalXY -
prevOutY[index] * reciprocalSquareSumY);
} else {
paddle::paddleAtomicAdd(prevGradY + index, output[ty] * grad[ty] *
(prevOutX[index] * reciprocalXY -
prevOutY[index] * reciprocalSquareSumY));
}
}
}
}
void hl_cossim_derivative(real* grad,
real* output,
real* prevOutX,
real* prevOutY,
real* prevGradX,
real* prevGradY,
int width,
int input1_height,
int input2_height,
real scale) {
CHECK_NOTNULL(grad);
CHECK_NOTNULL(output);
CHECK_NOTNULL(prevOutX);
CHECK_NOTNULL(prevOutY);
CHECK_NOTNULL(prevGradX);
CHECK_NOTNULL(prevGradY);
const int blockSize = 256;
dim3 threads(blockSize, 1);
dim3 grid(1, input1_height);
KeCosSimDerivative<blockSize><<<grid, threads, 0, STREAM_DEFAULT>>>
(grad, output, prevOutX, prevOutY, prevGradX, prevGradY, width,
input1_height, input2_height, scale);
CHECK_SYNC("hl_cossim_derivate failed");
}
__global__ void KeMatrixAddSharedBias(real* A,
real* B,
const int channel,
......
......@@ -27,6 +27,7 @@ if(WITH_TESTING)
add_simple_unittest(ContextProjectionOpTest)
add_simple_unittest(PadOpTest)
add_simple_unittest(MulOpTest)
add_simple_unittest(CosSimOpTest)
endif()
endif()
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "CosSimOp.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/Vector.h"
namespace paddle {
/**
* Cosine Similarity for CpuMatrix
*
* \param out_mat, output value, size: nSamples * 1.
* \param in1_mat, input value 1, size: nSamples * dim.
* \param in2_mat, input value 2, size: n2 * dim (n2 == 1 or n2 == nSamples).
* \param scale, default 1.0
*
*/
template <>
void CosSimForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat,
const CpuMatrix& in1_mat,
const CpuMatrix& in2_mat,
real scale) {
CHECK(out_mat.getData() && in1_mat.getData() && in2_mat.getData());
size_t num_samples = out_mat.getHeight();
size_t dim = in1_mat.getWidth();
/// column vector [nSamples, 1]
real* out = out_mat.getData();
const real* x = in1_mat.getData();
const real* y = in2_mat.getData();
/// in2 might only have one row or full rows
CHECK(in2_mat.getHeight() == 1LU || in2_mat.getHeight() == num_samples);
size_t inc = (in2_mat.getHeight() == 1LU) ? 0 : dim;
for (size_t i = 0; i < num_samples; ++i, x += dim, y += inc) {
real square_sum_x = 0;
real square_sum_y = 0;
real xy = 0;
for (size_t j = 0; j < dim; ++j) {
square_sum_x += x[j] * x[j];
square_sum_y += y[j] * y[j];
xy += x[j] * y[j];
}
CHECK(square_sum_x > 0 && square_sum_y > 0);
out[i] = scale * xy / (std::sqrt(square_sum_x) * std::sqrt(square_sum_y));
}
}
/**
* Cosine Similarity
* for each row i,
* out[i] = scale * cos(input1[i], input2[i])
* = scale * <input1[i], input2[i]>/sqrt(|input1[i]|^2 * |input2[i]|^2)
* when input2 only has one row, then for each row i,
* out[i] = cos(input1[i], input2[0])
*
* \param inputs[0] input matrix 1, size: nSamples * dim.
* \param inputs[1] input matrix 2, size: n2 * dim (n2 == 1 or n2 == nSamples).
* \param outputs[0] output matrix, size : nSamples * 1.
*/
template <DeviceType Device>
class CosSimForwardFunc : public FunctionBase {
void init(const FuncConfig& config) override {
scale_ = config.get<real>("scale");
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(inputs.size(), 2UL);
CHECK_EQ(outputs.size(), 1UL);
CHECK_EQ(inputs[0].shape().ndims(), 2UL);
CHECK_EQ(inputs[1].shape().ndims(), 2UL);
CHECK_EQ(outputs[0].shape().ndims(), 2UL);
CHECK_EQ(inputs[0].shape()[0], outputs[0].shape()[0]);
CHECK_EQ(inputs[0].shape()[1], inputs[1].shape()[1]);
CHECK_EQ(outputs[0].shape()[1], 1UL);
CHECK(outputs[0].data() && inputs[0].data() && inputs[1].data());
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
auto out_mat = outputs[0].matrix<Device>();
const auto in1_mat = inputs[0].matrix<Device>();
const auto in2_mat = inputs[1].matrix<Device>();
CosSimForward<Device>(out_mat, in1_mat, in2_mat, scale_);
}
private:
real scale_;
};
/**
* Cosine Similarity Derivative for CpuMatrix
*
* \param in1_grad forward input grad 1, size: nSamples * dim.
* \param in2_grad forward input grad 2,
* size: n2 * dim (n2 == 1 or n2 == nSamples).
*
* \param out_grad backward loss output grad, size : nSamples * 1.
* \param out_val forward output value, size: nSamples * 1.
* \param in1_val forward input value 1, size: nSamples * dim.
* \param in2_val forward input value 2,
* size: n2 * dim (n2 == 1 or n2 == nSamples).
* \param scale, default 1.0
*/
template <>
void CosSimBackward<DEVICE_TYPE_CPU>(const CpuMatrix& out_grad,
const CpuMatrix& out_val,
const CpuMatrix& in1_val,
const CpuMatrix& in2_val,
CpuMatrix& in1_grad,
CpuMatrix& in2_grad,
real scale) {
CHECK(out_grad.getData() && out_val.getData() && in1_val.getData() &&
in2_val.getData() && in1_grad.getData() && in2_grad.getData());
CHECK_EQ(out_val.useGpu_, false) << "Matrix type are GPU, CPU required";
const real* grad = out_grad.getData();
const real* out = out_val.getData();
const real* prev_out_x = in1_val.getData();
const real* prev_out_y = in2_val.getData();
real* prev_grad_x = in1_grad.getData();
real* prev_grad_y = in2_grad.getData();
size_t num_samples = out_grad.getHeight();
size_t dim = in1_val.getWidth();
CHECK_EQ(in2_val.getHeight(), in2_grad.getHeight());
CHECK(in2_val.getHeight() == 1LU || in2_val.getHeight() == num_samples);
size_t inc = (in2_val.getHeight() == 1LU) ? 0 : dim;
for (size_t i = 0; i < num_samples; ++i,
prev_out_x += dim,
prev_out_y += inc,
prev_grad_x += dim,
prev_grad_y += inc) {
real square_sum_x = 0;
real square_sum_y = 0;
real xy = 0;
for (size_t j = 0; j < dim; ++j) {
square_sum_x += prev_out_x[j] * prev_out_x[j];
square_sum_y += prev_out_y[j] * prev_out_y[j];
xy += prev_out_x[j] * prev_out_y[j];
}
CHECK(square_sum_x > 0 && square_sum_y > 0);
if (xy == 0) {
real reciprocal =
1.0f / (std::sqrt(square_sum_x) * std::sqrt(square_sum_y));
for (size_t j = 0; j < dim; ++j) {
prev_grad_x[j] += scale * grad[i] * prev_out_y[j] * reciprocal;
prev_grad_y[j] += scale * grad[i] * prev_out_x[j] * reciprocal;
}
} else {
real reciprocal_xy = 1.0f / xy;
real reciprocal_square_sum_x = 1.0f / square_sum_x;
real reciprocal_square_sum_y = 1.0f / square_sum_y;
for (size_t j = 0; j < dim; ++j) {
prev_grad_x[j] +=
out[i] * grad[i] * (prev_out_y[j] * reciprocal_xy -
prev_out_x[j] * reciprocal_square_sum_x);
prev_grad_y[j] +=
out[i] * grad[i] * (prev_out_x[j] * reciprocal_xy -
prev_out_y[j] * reciprocal_square_sum_y);
}
}
}
}
/**
* Cosine Similarity backward Derivative
*
* \param outputs[0] forward input grad 1, size: nSamples * dim.
* \param outputs[1] forward input grad 2,
* size: n2 * dim (n2 == 1 or n2 == nSamples).
*
* \param inputs[0] backward loss output grad, size : nSamples * 1.
* \param inputs[1] forward output value, size: nSamples * 1.
* \param inputs[2] forward input value 1, size: nSamples * dim.
* \param inputs[3] forward input value 2,
* size: n2 * dim (n2 == 1 or n2 == nSamples).
*/
template <DeviceType Device>
class CosSimBackwardFunc : public FunctionBase {
void init(const FuncConfig& config) override {
scale_ = config.get<real>("scale");
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(inputs.size(), 4UL);
CHECK_EQ(outputs.size(), 2UL);
/// dim of out_grad and out_val == 1, column vector
CHECK_EQ(inputs[0].shape()[1], 1UL);
CHECK_EQ(inputs[1].shape()[1], 1UL);
/// nSamples of out_grad == out_val == in_val1 == in_grad1
CHECK_EQ(inputs[1].shape()[0], inputs[0].shape()[0]);
CHECK_EQ(inputs[0].shape()[0], inputs[0].shape()[0]);
CHECK_EQ(outputs[0].shape()[0], inputs[0].shape()[0]);
/// dim of in1_val1 == in_val2 == in_grad1 == in_grad2
CHECK_EQ(inputs[3].shape()[1], inputs[2].shape()[1]);
CHECK_EQ(outputs[0].shape()[1], inputs[2].shape()[1]);
CHECK_EQ(outputs[1].shape()[1], inputs[2].shape()[1]);
CHECK(inputs[0].data() && inputs[1].data() && inputs[2].data() &&
inputs[3].data() && outputs[0].data() && outputs[1].data());
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
CHECK_EQ(outputs[1].getArgType(), ADD_TO);
const auto out_grad = inputs[0].matrix<Device>();
const auto out_val = inputs[1].matrix<Device>();
const auto in1_val = inputs[2].matrix<Device>();
const auto in2_val = inputs[3].matrix<Device>();
auto in1_grad = outputs[0].matrix<Device>();
auto in2_grad = outputs[1].matrix<Device>();
CosSimBackward<Device>(
out_grad, out_val, in1_val, in2_val, in1_grad, in2_grad, scale_);
}
private:
real scale_;
};
REGISTER_TYPED_FUNC(CosSimForward, CPU, CosSimForwardFunc);
REGISTER_TYPED_FUNC(CosSimBackward, CPU, CosSimBackwardFunc);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(CosSimForward, GPU, CosSimForwardFunc);
REGISTER_TYPED_FUNC(CosSimBackward, GPU, CosSimBackwardFunc);
#endif
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
namespace paddle {
/**
* \brief Cosine Similarity Forward.
* for each row i,
* out[i] = scale * cos(in1[i], in2[i])
* = scale * \sum_j (in1[i][j] * in2[i][j]) /
* sqrt(sum_j (in1[i][j]^2) * sum_j (in2[i][j])^2)
*
* \param[out] output output value.
* \param[in] intput1 input value.
* \param[in] intput2 input value.
* \param[in] scale default 1.0.
*
*/
template <DeviceType Device>
void CosSimForward(typename Tensor<real, Device>::Matrix& output,
const typename Tensor<real, Device>::Matrix& input1,
const typename Tensor<real, Device>::Matrix& input2,
real scale);
/**
* \brief Cosine Similarity BackWard for Derivative.
*
* \param[in] output grad backward loss output grad.
* \param[in] output val forward-output value.
* \param[in] input val1 forward input value 1.
* \param[in] input val2 forward input value 2.
* \param[in/out] input grad forward input grad 1.
* \param[in/out] input grad forward input grad 2.
* \param[in] scale default 1.0.
*
*/
template <DeviceType Device>
void CosSimBackward(const typename Tensor<real, Device>::Matrix& out_grad,
const typename Tensor<real, Device>::Matrix& out_value,
const typename Tensor<real, Device>::Matrix& in1_value,
const typename Tensor<real, Device>::Matrix& in2_value,
typename Tensor<real, Device>::Matrix& in1_grad,
typename Tensor<real, Device>::Matrix& in2_grad,
real scale);
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "hl_device_functions.cuh"
#include "CosSimOp.h"
namespace paddle {
template<int block_size>
__global__ void KeCosSim(real* output,
const real* input1,
const real* input2,
int width,
int input1_height,
int input2_height,
real scale) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
__shared__ real xx[block_size];
__shared__ real yy[block_size];
__shared__ real xy[block_size];
xx[tid] = 0.0;
yy[tid] = 0.0;
xy[tid] = 0.0;
__syncthreads();
input1 += ty * width;
if (input2_height > 1) {
input2 += ty * width;
}
for (int index = tid; index < width; index += block_size) {
real x = input1[index];
real y = input2[index];
xx[tid] += x * x;
yy[tid] += y * y;
xy[tid] += x * y;
}
__syncthreads();
for (int s = block_size / 2; s > 0; s >>= 1) {
if (tid < s) {
xx[tid] += xx[tid + s];
yy[tid] += yy[tid + s];
xy[tid] += xy[tid + s];
}
__syncthreads();
}
if (tid == 0) {
output[ty] = scale * xy[0] / (sqrt(xx[0]) * sqrt(yy[0]));
}
}
void hlCossim(real* output,
const real* input1,
const real* input2,
size_t width,
size_t input1_height,
size_t input2_height,
real scale) {
CHECK_NOTNULL(output);
CHECK_NOTNULL(input1);
CHECK_NOTNULL(input2);
const int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid(1, input1_height);
KeCosSim<block_size><<<grid, threads, 0, STREAM_DEFAULT>>>
(output, input1, input2, width, input1_height, input2_height, scale);
CHECK_SYNC("hlCossim failed");
}
template <>
void CosSimForward<DEVICE_TYPE_GPU>(GpuMatrix& out_mat,
const GpuMatrix& in1_mat,
const GpuMatrix& in2_mat,
real scale) {
CHECK(out_mat.getData() && in1_mat.getData() && in2_mat.getData());
CHECK(in1_mat.useGpu_ == true && in2_mat.useGpu_ == true)
<< "Matrix type are not GPU";
size_t num_samples = out_mat.getHeight();
size_t dim = in1_mat.getWidth();
real* out = out_mat.getData();
const real* x = in1_mat.getData();
const real* y = in2_mat.getData();
hlCossim(out, x, y, dim, in1_mat.getHeight(), in2_mat.getHeight(), scale);
}
template<int block_size>
__global__ void KeCosSimDerivative(const real* grad,
const real* output,
const real* prev_out_x,
const real* prev_out_y,
real* prev_grad_x,
real* prev_grad_y,
size_t width,
size_t input1_height,
size_t input2_height,
real scale) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
__shared__ real xx[block_size];
__shared__ real yy[block_size];
__shared__ real xy[block_size];
xx[tid] = 0.0;
yy[tid] = 0.0;
xy[tid] = 0.0;
__syncthreads();
prev_out_x += ty * width;
prev_grad_x += ty * width;
if (input2_height > 1) {
prev_out_y += ty * width;
prev_grad_y += ty * width;
}
for (int index = tid; index < width; index += block_size) {
real x = prev_out_x[index];
real y = prev_out_y[index];
xx[tid] += x * x;
yy[tid] += y * y;
xy[tid] += x * y;
}
__syncthreads();
for (int s = block_size / 2; s > 0; s >>= 1) {
if (tid < s) {
xx[tid] += xx[tid + s];
yy[tid] += yy[tid + s];
xy[tid] += xy[tid + s];
}
__syncthreads();
}
if (xy[0] == 0) {
real reciprocal = 1.0 / (sqrt(xx[0]) * sqrt(yy[0]));
for (int index = tid; index < width; index += block_size) {
prev_grad_x[index] +=
scale * grad[ty] * prev_out_y[index] * reciprocal;
if (input2_height > 1) {
prev_grad_y[index] +=
scale * grad[ty] * prev_out_x[index] * reciprocal;
} else {
paddle::paddleAtomicAdd(prev_grad_y + index,
scale * grad[ty] * prev_out_x[index] * reciprocal);
}
}
} else {
real reciprocalXY = 1.0 / xy[0];
real reciprocalSquareSumX = 1.0 / xx[0];
real reciprocalSquareSumY = 1.0 / yy[0];
for (int index = tid; index < width; index += block_size) {
prev_grad_x[index] += output[ty] * grad[ty] *
(prev_out_y[index] * reciprocalXY -
prev_out_x[index] * reciprocalSquareSumX);
if (input2_height > 1) {
prev_grad_y[index] += output[ty] * grad[ty] *
(prev_out_x[index] * reciprocalXY -
prev_out_y[index] * reciprocalSquareSumY);
} else {
paddle::paddleAtomicAdd(prev_grad_y + index, output[ty] * grad[ty] *
(prev_out_x[index] * reciprocalXY -
prev_out_y[index] * reciprocalSquareSumY));
}
}
}
}
void hlCossimDerivative(const real* grad,
const real* output,
const real* prev_out_x,
const real* prev_out_y,
real* prev_grad_x,
real* prev_grad_y,
size_t width,
size_t input1_height,
size_t input2_height,
real scale) {
CHECK_NOTNULL(grad);
CHECK_NOTNULL(output);
CHECK_NOTNULL(prev_out_x);
CHECK_NOTNULL(prev_out_y);
CHECK_NOTNULL(prev_grad_x);
CHECK_NOTNULL(prev_grad_y);
const int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid(1, input1_height);
KeCosSimDerivative<block_size><<<grid, threads, 0, STREAM_DEFAULT>>>
(grad, output, prev_out_x, prev_out_y, prev_grad_x, prev_grad_y, width,
input1_height, input2_height, scale);
CHECK_SYNC("hlCossimDerivate failed");
}
template <>
void CosSimBackward<DEVICE_TYPE_GPU>(const GpuMatrix& out_grad,
const GpuMatrix& out_val,
const GpuMatrix& in1_val,
const GpuMatrix& in2_val,
GpuMatrix& in1_grad,
GpuMatrix& in2_grad,
real scale) {
CHECK(out_grad.getData() && out_val.getData() && in1_val.getData() &&
in2_val.getData() && in1_grad.getData() && in2_grad.getData());
CHECK(out_grad.useGpu_ && out_val.useGpu_ && in1_val.useGpu_
&& in2_val.useGpu_ && in1_grad.useGpu_ && in2_grad.useGpu_)
<< "Matrix types are not equally GPU";
size_t dim = in1_val.getWidth();
const real* grad = out_grad.getData();
const real* out = out_val.getData();
const real* prev_out_x = in1_val.getData();
const real* prev_out_y = in2_val.getData();
real* prev_grad_x = in1_grad.getData();
real* prev_grad_y = in2_grad.getData();
hlCossimDerivative(grad,
out,
prev_out_x,
prev_out_y,
prev_grad_x,
prev_grad_y,
dim,
in1_val.getHeight(),
in2_val.getHeight(),
scale);
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "FunctionTest.h"
#include "paddle/math/Matrix.h"
using namespace paddle; // NOLINT
void testCosSimForward(size_t height_x,
size_t height_y,
size_t width,
real scale) {
FunctionCompare test("CosSimForward", FuncConfig().set("scale", scale));
// prepare input arguments
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width}));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width}));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1}),
ASSIGN_TO);
// run Function
test.run();
}
void testCosSimBackward(size_t height_x,
size_t height_y,
size_t width,
real scale) {
FunctionCompare test("CosSimBackward", FuncConfig().set("scale", scale));
// prepare input arguments
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1}));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1}));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width}));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width}));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width}),
ADD_TO);
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width}),
ADD_TO);
// run Function
test.run();
}
TEST(Matrix, cosSim) {
for (auto height_x : {10, 100, 1000}) {
for (auto height_y : {1, height_x}) {
for (auto width : {10, 100, 1000}) {
for (auto scale : {1.0, 2.0}) {
testCosSimForward(height_x, height_y, width, scale);
testCosSimBackward(height_x, height_y, width, scale);
}
}
}
}
}
......@@ -26,15 +26,23 @@ bool CosSimLayer::init(const LayerMap& layerMap,
Layer::init(layerMap, parameterMap);
CHECK_EQ(inputLayers_.size(), 2LU);
createFunction(forward_,
"CosSimForward",
FuncConfig().set("scale", (real)config_.cos_scale()));
createFunction(backward_,
"CosSimBackward",
FuncConfig().set("scale", (real)config_.cos_scale()));
return true;
}
void CosSimLayer::forward(PassType passType) {
Layer::forward(passType);
/* malloc memory for the output_ if necessary */
int batchSize = getInputValue(0)->getHeight();
int size = getSize();
CHECK_EQ(forward_.size(), 1) << "Only one forward function needed";
{
REGISTER_TIMER_INFO("CosFwResetTimer", getName().c_str());
......@@ -42,26 +50,43 @@ void CosSimLayer::forward(PassType passType) {
}
MatrixPtr outV = getOutputValue();
/* activation */ {
REGISTER_TIMER_INFO("CosFwAtvTimer", getName().c_str());
MatrixPtr prevOut1 = getInputValue(0);
MatrixPtr prevOut2 = getInputValue(1);
outV->cosSim(*prevOut1, *prevOut2, config_.cos_scale());
CHECK(outV && prevOut1 && prevOut2);
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*prevOut1);
inputs.addArg(*prevOut2);
outputs.addArg(*outV, ASSIGN_TO);
forward_[0]->calc(inputs, outputs);
}
}
void CosSimLayer::backward(const UpdateCallback& callback) {
/* activation */ {
REGISTER_TIMER_INFO("CosBpAtvTimer", getName().c_str());
MatrixPtr outG = this->getOutputGrad();
outG->cosSimDerivative(*this->getOutputValue(),
*getInputValue(0),
*getInputValue(1),
*getInputGrad(0),
*getInputGrad(1),
config_.cos_scale());
CHECK_EQ(backward_.size(), 1) << "Only one backward function needed";
const auto outG = this->getOutputGrad();
const auto outV = this->getOutputValue();
const auto inV1 = this->getInputValue(0);
const auto inV2 = this->getInputValue(1);
auto inG1 = this->getInputGrad(0);
auto inG2 = this->getInputGrad(1);
CHECK(outG && outV && inV1 && inV2 && inG1 && inG2);
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*outG);
inputs.addArg(*outV);
inputs.addArg(*inV1);
inputs.addArg(*inV2);
outputs.addArg(*inG1, ADD_TO);
outputs.addArg(*inG2, ADD_TO);
backward_[0]->calc(inputs, outputs);
}
}
......
......@@ -28,7 +28,7 @@ namespace paddle {
*
* - Input1: A vector (batchSize * dataDim) *
* - Input2: A vector (batchSize * dataDim) or (1 * dataDim) *
* - Output: A vector (dataDim * 1)
* - Output: A vector (batchSize * 1)
*
* The config file api is cos_sim.
*/
......
......@@ -18,7 +18,6 @@ limitations under the License. */
#include "paddle/utils/Stat.h"
namespace paddle {
/**
* @brief A layer for computing cosine similarity between a vector
* and each row of a matrix
......@@ -98,11 +97,22 @@ bool CosSimVecMatLayer::init(const LayerMap& layerMap,
dataDim,
/* trans= */ false,
useGpu_);
CHECK(tmpRow0 && tmpRow1 && tmpRow2 && tmpRow3 && tmpMtx0 && tmpMtx1);
createFunction(forward_,
"CosSimForward",
FuncConfig().set("scale", (real)config_.cos_scale()));
createFunction(backward_,
"CosSimBackward",
FuncConfig().set("scale", (real)config_.cos_scale()));
return true;
}
void CosSimVecMatLayer::forward(PassType passType) {
Layer::forward(passType);
CHECK_EQ(forward_.size(), 1) << "Only one forward function needed";
MatrixPtr inV0 = getInputValue(0);
MatrixPtr inV1 = getInputValue(1);
......@@ -118,17 +128,25 @@ void CosSimVecMatLayer::forward(PassType passType) {
}
MatrixPtr outV = getOutputValue();
CHECK(outV && inV0 && inV1);
REGISTER_TIMER_INFO("FwCosVMTimer", getName().c_str());
for (size_t i = 0; i < batchSize; i++) {
tmpRow0->setData(inV0->rowBuf(i));
tmpMtx0->setData(inV1->rowBuf(i));
tmpRow2->setData(outV->rowBuf(i));
tmpRow2->cosSim(*(tmpMtx0), *(tmpRow0), config_.cos_scale());
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*tmpMtx0);
inputs.addArg(*tmpRow0);
outputs.addArg(*tmpRow2, ASSIGN_TO);
forward_[0]->calc(inputs, outputs);
}
}
void CosSimVecMatLayer::backward(const UpdateCallback& callback) {
CHECK_EQ(backward_.size(), 1) << "Only one forward function needed";
MatrixPtr inV0 = getInputValue(0);
MatrixPtr inV1 = getInputValue(1);
MatrixPtr inG0 = getInputGrad(0);
......@@ -137,27 +155,27 @@ void CosSimVecMatLayer::backward(const UpdateCallback& callback) {
MatrixPtr outG = getOutputGrad();
size_t batchSize = inV0->getHeight();
CHECK(inV0 && inV1 && inG0 && inG1 && outV && outG);
REGISTER_TIMER_INFO("BwCosVMTimer", getName().c_str());
if (inG0 && inG1) {
for (size_t i = 0; i < batchSize; i++) {
tmpRow0->setData(inV0->rowBuf(i));
tmpRow1->setData(inG0->rowBuf(i));
tmpMtx0->setData(inV1->rowBuf(i));
tmpMtx1->setData(inG1->rowBuf(i));
tmpRow2->setData(outV->rowBuf(i));
tmpRow3->setData(outG->rowBuf(i));
tmpRow3->cosSimDerivative(*(tmpRow2),
*(tmpMtx0),
*(tmpRow0),
*(tmpMtx1),
*(tmpRow1),
config_.cos_scale());
}
} else {
CHECK(!inG0 || !inG1) << "Not supported";
for (size_t i = 0; i < batchSize; i++) {
tmpRow0->setData(inV0->rowBuf(i));
tmpRow1->setData(inG0->rowBuf(i));
tmpMtx0->setData(inV1->rowBuf(i));
tmpMtx1->setData(inG1->rowBuf(i));
tmpRow2->setData(outV->rowBuf(i));
tmpRow3->setData(outG->rowBuf(i));
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*tmpRow3);
inputs.addArg(*tmpRow2);
inputs.addArg(*tmpMtx0);
inputs.addArg(*tmpRow0);
outputs.addArg(*tmpMtx1, ADD_TO);
outputs.addArg(*tmpRow1, ADD_TO);
backward_[0]->calc(inputs, outputs);
}
}
......
......@@ -941,59 +941,6 @@ void GpuMatrix::softreluDerivative(Matrix& output) {
void GpuMatrix::scaledTanh(Matrix& output, real p1, real p2) {
BaseMatrix::scaledTanh(output, p1, p2);
}
void GpuMatrix::cosSim(Matrix& output1, Matrix& output2, real scale) {
CHECK(output1.useGpu_ == true && output2.useGpu_ == true)
<< "Matrix type are not equal";
size_t numSamples = getHeight();
size_t dim = output1.getWidth();
CHECK_EQ(getWidth(), 1UL);
CHECK_EQ(output1.getHeight(), numSamples);
CHECK_EQ(output1.getWidth(), output2.getWidth());
real* out = getData();
real* x = output1.getData();
real* y = output2.getData();
hl_cossim(out, x, y, dim, output1.getHeight(), output2.getHeight(), scale);
}
void GpuMatrix::cosSimDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2,
real scale) {
CHECK(output.useGpu_ == true && prevOut1.useGpu_ == true &&
prevOut2.useGpu_ == true && prevGrad1.useGpu_ == true &&
prevGrad2.useGpu_ == true)
<< "Matrix type are not equal";
CHECK_EQ(getWidth(), 1UL);
CHECK_EQ(output.getWidth(), 1UL);
size_t numSamples = getHeight();
CHECK_EQ(output.getHeight(), numSamples);
CHECK_EQ(prevOut1.getHeight(), numSamples);
CHECK_EQ(prevGrad1.getHeight(), numSamples);
size_t dim = prevOut1.getWidth();
CHECK_EQ(prevOut2.getWidth(), dim);
CHECK_EQ(prevGrad1.getWidth(), dim);
CHECK_EQ(prevGrad2.getWidth(), dim);
real* grad = getData();
real* out = output.getData();
real* prevOutX = prevOut1.getData();
real* prevOutY = prevOut2.getData();
real* prevGradX = prevGrad1.getData();
real* prevGradY = prevGrad2.getData();
hl_cossim_derivative(grad,
out,
prevOutX,
prevOutY,
prevGradX,
prevGradY,
dim,
prevOut1.getHeight(),
prevOut2.getHeight(),
scale);
}
void GpuMatrix::randomizeUniform() {
CHECK(isContiguous());
......@@ -3470,105 +3417,6 @@ void CpuMatrix::softmaxDerivative(Matrix& output, Matrix& sftmaxSum) {
}
}
void CpuMatrix::cosSim(Matrix& output1, Matrix& output2, real scale) {
size_t numSamples = getHeight();
size_t dim = output1.getWidth();
CHECK_EQ(getWidth(), 1UL);
CHECK_EQ(output1.getHeight(), numSamples);
CHECK_EQ(output1.getWidth(), output2.getWidth());
real* out = getData();
const real* x = output1.getData();
const real* y = output2.getData();
size_t yInc = dim;
if (output2.getHeight() == 1LU) {
yInc = 0;
} else {
CHECK_EQ(output2.getHeight(), numSamples);
}
for (size_t i = 0; i < numSamples; ++i, x += dim, y += yInc) {
real squareSumX = 0;
real squareSumY = 0;
real xy = 0;
for (size_t j = 0; j < dim; ++j) {
squareSumX += _square(x[j]);
squareSumY += _square(y[j]);
xy += x[j] * y[j];
}
CHECK(squareSumX > 0 && squareSumY > 0);
out[i] = scale * xy / (std::sqrt(squareSumX) * std::sqrt(squareSumY));
}
}
void CpuMatrix::cosSimDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2,
real scale) {
CHECK(output.useGpu_ == false) << "Matrix type are not equal";
CHECK_EQ(getWidth(), 1UL);
CHECK_EQ(output.getWidth(), 1UL);
size_t numSamples = getHeight();
CHECK_EQ(output.getHeight(), numSamples);
CHECK_EQ(prevOut1.getHeight(), numSamples);
CHECK_EQ(prevGrad1.getHeight(), numSamples);
size_t dim = prevOut1.getWidth();
CHECK_EQ(prevOut2.getWidth(), dim);
CHECK_EQ(prevGrad1.getWidth(), dim);
CHECK_EQ(prevGrad2.getWidth(), dim);
const real* grad = getData();
const real* out = output.getData();
const real* prevOutX = prevOut1.getData();
const real* prevOutY = prevOut2.getData();
real* prevGradX = prevGrad1.getData();
real* prevGradY = prevGrad2.getData();
size_t yInc = dim;
if (prevOut2.getHeight() == 1LU) {
yInc = 0;
CHECK_EQ(prevGrad2.getHeight(), 1LU);
} else {
CHECK_EQ(prevOut2.getHeight(), numSamples);
CHECK_EQ(prevGrad2.getHeight(), numSamples);
}
for (size_t i = 0; i < numSamples; ++i,
prevOutX += dim,
prevOutY += yInc,
prevGradX += dim,
prevGradY += yInc) {
real squareSumX = 0;
real squareSumY = 0;
real xy = 0;
for (size_t j = 0; j < dim; ++j) {
squareSumX += _square(prevOutX[j]);
squareSumY += _square(prevOutY[j]);
xy += prevOutX[j] * prevOutY[j];
}
CHECK(squareSumX > 0 && squareSumY > 0);
if (xy == 0) {
real reciprocal = 1.0f / (std::sqrt(squareSumX) * std::sqrt(squareSumY));
for (size_t j = 0; j < dim; ++j) {
prevGradX[j] += scale * grad[i] * prevOutY[j] * reciprocal;
prevGradY[j] += scale * grad[i] * prevOutX[j] * reciprocal;
}
} else {
real reciprocalXY = 1.0f / xy;
real reciprocalSquareSumX = 1.0f / squareSumX;
real reciprocalSquareSumY = 1.0f / squareSumY;
for (size_t j = 0; j < dim; ++j) {
prevGradX[j] += out[i] * grad[i] * (prevOutY[j] * reciprocalXY -
prevOutX[j] * reciprocalSquareSumX);
prevGradY[j] += out[i] * grad[i] * (prevOutX[j] * reciprocalXY -
prevOutY[j] * reciprocalSquareSumY);
}
}
}
}
void CpuMatrix::sumOfSquares(Matrix& output, Matrix& label) {
CHECK(output.useGpu_ == false && label.useGpu_ == false)
<< "Matrix type are not equal";
......
......@@ -799,26 +799,6 @@ public:
LOG(FATAL) << "Not implemented";
}
/**
* cosine similarity, for each row i,
* this[i] = cos(output1[i], output2[i])
*
* output2 can only have one row, then for each row i,
* this[i] = cos(output1[i], output2[0])
*/
virtual void cosSim(Matrix& output1, Matrix& output2, real scale = 1.0f) {
LOG(FATAL) << "Not implemented";
}
virtual void cosSimDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2,
real scale = 1.0f) {
LOG(FATAL) << "Not implemented";
}
/// print out the values of elements to os
virtual void print(std::ostream& os) const {
LOG(FATAL) << "Not implemented";
......@@ -1324,14 +1304,6 @@ public:
void softreluDerivative(Matrix& output);
void scaledTanh(Matrix& output, real p1, real p2);
void cosSim(Matrix& output1, Matrix& output2, real scale);
void cosSimDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2,
real scale);
virtual void print(std::ostream& os) const;
virtual void print(std::ostream& os, size_t height, size_t width) const;
......@@ -1752,14 +1724,6 @@ public:
void softreluDerivative(Matrix& output);
void scaledTanh(Matrix& output, real p1, real p2);
void cosSim(Matrix& output1, Matrix& output2, real scale);
void cosSimDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2,
real scale);
void print(std::ostream& os) const;
void print(std::ostream& os, size_t height, size_t width) const;
void printOneRow(std::ostream& os, size_t idx) const;
......
......@@ -181,28 +181,6 @@ TEST(Matrix, copyByRowIndex) {
}
}
void testCosSim(int heightX, int heightY, int width, real scale) {
AutoCompare test(heightX, 1);
CpuMatrix arg1(heightX, width);
CpuMatrix arg2(heightY, width);
arg1.randomizeUniform();
arg2.randomizeUniform();
arg2.add(-0.5);
test.cmpWithArg(&Matrix::cosSim, arg1, arg2, scale);
}
TEST(Matrix, cosSim) {
for (auto heightX : {10, 100, 1000}) {
for (auto heightY : {1, heightX}) {
for (auto width : {10, 100, 1000}) {
for (auto scale : {1.0, 2.0}) {
testCosSim(heightX, heightY, width, scale);
}
}
}
}
}
void testParamReluForward(int height, int width, int w_height, int w_width) {
AutoCompare test(height, width);
CpuMatrix arg1(height, width);
......
......@@ -720,61 +720,6 @@ TEST(Matrix, sequenceAvgForward) {
}
}
void testCosSimDerivate(int heightX, int heightY, int width, real scale) {
MatrixPtr prevOutX = CpuMatrix::create(heightX, width, false, false);
MatrixPtr prevOutY = CpuMatrix::create(heightY, width, false, false);
MatrixPtr grad = CpuMatrix::create(heightX, 1, false, false);
MatrixPtr output = CpuMatrix::create(heightX, 1, false, false);
MatrixPtr prevGradX = CpuMatrix::create(heightX, width, false, false);
MatrixPtr prevGradY = CpuMatrix::create(heightY, width, false, false);
prevOutX->randomizeUniform();
prevOutY->randomizeUniform();
grad->randomizeUniform();
output->randomizeUniform();
prevGradX->randomizeUniform();
prevGradY->randomizeUniform();
MatrixPtr prevOutXGpu = GpuMatrix::create(heightX, width, false, true);
MatrixPtr prevOutYGpu = GpuMatrix::create(heightY, width, false, true);
MatrixPtr gradGpu = GpuMatrix::create(heightX, 1, false, true);
MatrixPtr outputGpu = GpuMatrix::create(heightX, 1, false, true);
MatrixPtr prevGradXGpu = GpuMatrix::create(heightX, width, false, true);
MatrixPtr prevGradYGpu = GpuMatrix::create(heightY, width, false, true);
prevOutXGpu->copyFrom(*prevOutX);
prevOutYGpu->copyFrom(*prevOutY);
gradGpu->copyFrom(*grad);
outputGpu->copyFrom(*output);
prevGradXGpu->copyFrom(*prevGradX);
prevGradYGpu->copyFrom(*prevGradY);
grad->cosSimDerivative(
*output, *prevOutX, *prevOutY, *prevGradX, *prevGradY, scale);
gradGpu->cosSimDerivative(*outputGpu,
*prevOutXGpu,
*prevOutYGpu,
*prevGradXGpu,
*prevGradYGpu,
scale);
TensorCheckErr(*prevGradX, *prevGradXGpu);
TensorCheckErr(*prevGradY, *prevGradYGpu);
}
TEST(Matrix, cosSimDerivate) {
for (auto heightX : {1, 10, 100}) {
for (auto heightY : {1, heightX}) {
for (auto width : {1, 10, 100}) {
for (auto scale : {1.0, 2.0}) {
testCosSimDerivate(heightX, heightY, width, scale);
}
}
}
}
}
void testParamReluBackwardDiff(int height,
int width,
int w_height,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册