提交 281250f5 编写于 作者: Q qiaolongfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into v2-network

......@@ -54,7 +54,9 @@ before_install:
fi
- if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then paddle/scripts/travis/before_install.osx.sh; fi
- if [[ "$JOB" == "PRE_COMMIT" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi
- pip install numpy wheel protobuf sphinx recommonmark sphinx_rtd_theme virtualenv pre-commit requests==2.9.2 LinkChecker
# Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python
# protobuf version.
- pip install numpy wheel 'protobuf==3.1' sphinx recommonmark sphinx_rtd_theme virtualenv pre-commit requests==2.9.2 LinkChecker
script:
- paddle/scripts/travis/main.sh
notifications:
......
# PaddlePaddle Design Doc
## Ingredients
As our design principle is starting from the essence: how could we
allow users to express and solve their problems at neural networks.
Some essential concepts that our API have to provide include:
1. A *topology* is an expression of *layers*.
1. A layer could be any kind of computation, including *cost*.
1. Some layers have parameters, some don't. Most costs don't have
parameters.
1. In some topologies, layers share parameters. For
example,
[the network for training a ranking model](https://github.com/PaddlePaddle/Paddle/issues/1311#issuecomment-279121850).
1. At programming time, users specify topologies and possible sharing
of parameters. PaddlePaddle can figure out and create parameters
required (and possibly shared) by one or more topologies.
## Starting from Examples
As a summarization
of
[our disucssion](https://github.com/PaddlePaddle/Paddle/issues/1315),
let us present two examples here:
### Example 1. Sharing Parameters between Layers
We use
the
[3-branch ranking](https://github.com/PaddlePaddle/Paddle/issues/1311#issuecomment-279121850) model
in this example. For your convenience, I copy-a-paste the model's
topology as follows:
```
A -> f -\
Q -> f --> cost
B -> f -/
```
The following program trains the topology including the cost, and then
use the sub-network in the trained topology in inference:
```python
def f(in):
e = paddle.layer.embedding(in, parameter_name="embedding")
o = paddle.layer.softmax(e, parameter_name="semantic")
return o
# Create 3 topologies (subnets), they share parameters because all
# correspoinding layers have the same parameter names.
fA = f(paddle.layer.data(input_name="A"))
fB = f(paddle.layer.data(input_name="B"))
fQ = f(paddle.layer.data(input_name="Q"))
topology = paddle.layer.less_than(
paddle.layer.cross_entropy(fA, fQ),
paddle.layer.corss_entropy(fB, fQ))
# Derive parameters required in topology and create them in model.
parameters = paddle.parameters.create(topology)
# Estimate parameters used in topology from data.
paddle.train(topology, parameters, reader=read_ranking_model_data)
# Inference using fA (or fB or fC, as they share their parameters).
[testA, testB, testQ] = read_ranking_model_data()
print "The sematic-vector of testA: ", paddle.infer(fA, parameters, testA)
```
### Example 2. Sharing Parameters between "Models"
We use [GAN](https://github.com/PaddlePaddle/book/tree/develop/gan) in
this example. In the following example program, `d0` and `d1`
correspond to the two networks in the following figure:
<img src="https://github.com/wangyang59/book/raw/00036f4b0da5225041a6824587c1a01cf20159b1/gan/image/gan_ig.png" width=400 />
```python
def G(in):
# over-simplified example as G has only one layers:
return paddle.layer.fc(in, parameter_name="G")
def D(in);
# again, over-simplified:
return paddle.layer.fc(in, parameter_name="D")
# Construct the first topology, which contains both D and G.
# By learning this topology, we update parameters of G.
d0 = paddle.layer.should_be_false(D(G(paddle.layer.data())))
# Construct a second topology d1, which contains only D. By
# training this topology, we update parameters of D. Note
# that d1 share parameters with d0.
d1 = paddle.layer.should_be_true(D(paddle.layer.data()))
# Create parameters from a list of multiple topologies (models) for
# the chance to share parameters between these topologies.
parameters = paddle.parameters.create([d0, d1])
# Iterative training of GAN.
for ...:
train(d0, parameters, reader=read_from_rng, immutable_parameters={"D"})
train(d1, parameters, reader=read_from_realistic_images)
# Use d1 for inference:
print "D thinks a batch of images are realistic ", infer(d1, parameters, read_mnist_images)
```
### Summarization
Above two programs reveal some important design concerns:
1. Users describe a topology as an expression of layers. Every layer
has a *parameter name*. If the users don't specify it explicitly, it's automatically generated as a unique name. By
specifying the parameter name, users can specify the sharing of
parameters between layers and even between topologies.
1. `paddle.parameters.create` figures out parameters required by one
or more topologies from parameter names of layers. It creates these
parameters and returns a `ParameterSet` object, which is in essence
a map from *parameter names* to *parameters*.
1. At training and inference time, `paddle.train` and `paddle.infer`
requires both a topology and the parameter set that holds the parameters of that topology. There are some reasons:
1. This prevents users from forgetting to call
`paddle.parameters.create`.
1. `paddle.train` needs to know which parameter set to update.
1. Users could load another (pre-trained) parameter set and use it
with a topology in `train.infer`.
1. By specifying the `immutable_parameters` parameter of
`paddle.train`, we can forbid the update of these parameters.
## Reader
Not all programming frameworks allow users to define I/O functions.
An example is Google MapReduce, which can only read from text,
SSTable, and RecordIO files. Hadoop MapReduce allows users to define
readers and writers by deriving from base classes `Reader` and
`Writer`. The former is less flexible but also less error-prone. We
decide to provide the flexibility to users to define their readers.
There are some open questions here:
1. **Should a reader return a Python dictionary?**
1. **How to map multiple outputs from a reader to multiple data layers?**
1. **How to easily compose some existing readers to read more data and
feed a topology with more data layers?**
## Training
The recommended way to training a model is to call `paddle.train`,
which simply calls `paddle.trainer.Default`, a global variable of
type `paddle.trainer.SGD`. Equivalently, we can do
```python
opt = paddle.trainer.SGD(..., paddle.updater.Adam(...))
opt.train(topology, parameters, reader=read, ...)
```
### Updater
Please be aware that a trainer can accept an updater as its data
member, where an updater is a class derived from
`paddle.trainer.Updater`. This is to make it easier to customize
trainers, as discussed
[here](https://github.com/PaddlePaddle/Paddle/issues/1319).
### Event Handler
`paddle.train` and `paddle.trainer.XXX.train` take an optional
parameter `event_handler`, which should be either `None` or a function
that handle some events:
1. BeginTraining
1. EndTraining
1. BeginIteration
1. EndIteration
1. BeginPass
1. EndPass
where EndPass is sent if and only if the reader yields
`end_pass=True`.
An example as follows:
```python
def event_handler(event):
if ininstance(event, paddle.event.EndIteration):
print paddle.test(...)
paddle.train(topology, parameters, reader, event_handler)
```
If we are writing a PaddlePaddle program in and for iPython/Jypyter,
we can use metaplotlib in the event handler to plot a curve of
cost/error versus iterations, as shown
[here](https://blog.dominodatalab.com/interactive-dashboards-in-jupyter/).
### Distributed Training
If users want to do distributed training on a cluster, s/he should
call `paddle.dist_train` and provides access tokens to the cluster as
a parameter.
For example, if the user has a TLS certificate that allows him to
access a Kubernetes cluster, s/he should be able to call
```python
paddle.dist_train(model,
trainer=paddle.trainer.SGD(...,
paddle.updater.Adam(...)),
reader=read,
k8s_user="yi",
k8s_token="kube_cluster_tls.pem",
k8s_job="hello",
num_parameter_servers=15)
```
The pseudo code if `paddle.dist_train` is as follows:
```python
def dist_train(topology, parameters, trainer, reader, ...):
if os.getenv("KUBERNETES_SERVICE_HOST") == None:
image_name = k8s_user + '/' + k8s_job
docker_build(image_name)
docker_push()
kube_ctrl_start_job(image_name, k8s_user, k8s_token)
else:
rank = kube_list_containers_in_job_and_return_current_containers_rank()
if rank == 0:
master()
elif rank < 15:
parameter_server()
else:
trainer.train(model, reader=read)
```
Please be aware that if a process is running on the Kubernetes
cluster, it will have some environment variables pre-defined.
If `dist_train` doesn't see these environment variables, it knows
that it's running on users' personal computer, and it should work as a
*launcher*. Otherwise, it knows that it's running on the cluster and
need to figure out its role as either the master, or a trainer, or a
parameter server.
......@@ -188,48 +188,6 @@ extern void hl_param_relu_backward_diff(real* grad_o,
int width,
int height,
int partial_sum);
/**
* @brief cos sim forward
*
* @param[out] output output data
* @param[in] input1 input1 data(matrix)
* @param[in] input2 input2 data(matrix or vector)
* @param[in] width matrix width
* @param[in] input1_height input1_height
* @param[in] input2_height input2_height
* @param[in] scale scale factor
*/
extern void hl_cossim(real* output,
real* input1,
real* input2,
int width,
int input1_height,
int input2_height,
real scale);
/**
* @brief cos sim derivate
*
* @param[in] grad output grad
* @param[in] output output data
* @param[in] prevOutX input1 data
* @param[in] prevOutY input2 data
* @param[out] prevGradX input1 grad
* @param[out] prevGradY input2 grad
* @param[in] width matrix width
* @param[in] input1_height input1 height
* @param[in] input2_height input2 height
* @param[in] scale scale factor
*/
extern void hl_cossim_derivative(real* grad,
real* output,
real* prevOutX,
real* prevOutY,
real* prevGradX,
real* prevGradY,
int width,
int input1_height,
int input2_height,
real scale);
/**
* @brief Matrix addition: A_d[i][j] += scale * B_d[j/channel].
......
......@@ -74,25 +74,6 @@ inline void hl_param_relu_backward_diff(real* grad_o,
int height,
int partial_sum) {}
inline void hl_cossim(real* output,
real* input1,
real* input2,
int width,
int input1_height,
int input2_height,
real scale) {}
inline void hl_cossim_derivative(real* grad,
real* output,
real* prevOutX,
real* prevOutY,
real* prevGradX,
real* prevGradY,
int width,
int input1_height,
int input2_height,
real scale) {}
inline void hl_matrix_add_shared_bias(real* A_d,
real* B_d,
const int channel,
......
......@@ -584,177 +584,6 @@ void hl_param_relu_backward_diff(real* grad_o,
CHECK_SYNC("hl_param_relu_backward_diff failed");
}
template<int blockSize>
__global__ void KeCosSim(real* output,
real* input1,
real* input2,
int width,
int input1_height,
int input2_height,
real scale) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
__shared__ real xx[blockSize];
__shared__ real yy[blockSize];
__shared__ real xy[blockSize];
xx[tid] = 0.0;
yy[tid] = 0.0;
xy[tid] = 0.0;
__syncthreads();
input1 += ty * width;
if (input2_height > 1) {
input2 += ty * width;
}
for (int index = tid; index < width; index += blockSize) {
real x = input1[index];
real y = input2[index];
xx[tid] += x * x;
yy[tid] += y * y;
xy[tid] += x * y;
}
__syncthreads();
for (int s = blockSize / 2; s > 0; s >>= 1) {
if (tid < s) {
xx[tid] += xx[tid + s];
yy[tid] += yy[tid + s];
xy[tid] += xy[tid + s];
}
__syncthreads();
}
if (tid == 0) {
output[ty] = scale * xy[0] / (sqrt(xx[0]) * sqrt(yy[0]));
}
}
void hl_cossim(real* output,
real* input1,
real* input2,
int width,
int input1_height,
int input2_height,
real scale) {
CHECK_NOTNULL(output);
CHECK_NOTNULL(input1);
CHECK_NOTNULL(input2);
const int blockSize = 256;
dim3 threads(blockSize, 1);
dim3 grid(1, input1_height);
KeCosSim<blockSize><<<grid, threads, 0, STREAM_DEFAULT>>>
(output, input1, input2, width, input1_height, input2_height, scale);
CHECK_SYNC("hl_cossim failed");
}
template<int blockSize>
__global__ void KeCosSimDerivative(real* grad,
real* output,
real* prevOutX,
real* prevOutY,
real* prevGradX,
real* prevGradY,
int width,
int input1_height,
int input2_height,
real scale) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
__shared__ real xx[blockSize];
__shared__ real yy[blockSize];
__shared__ real xy[blockSize];
xx[tid] = 0.0;
yy[tid] = 0.0;
xy[tid] = 0.0;
__syncthreads();
prevOutX += ty * width;
prevGradX += ty * width;
if (input2_height > 1) {
prevOutY += ty * width;
prevGradY += ty * width;
}
for (int index = tid; index < width; index += blockSize) {
real x = prevOutX[index];
real y = prevOutY[index];
xx[tid] += x * x;
yy[tid] += y * y;
xy[tid] += x * y;
}
__syncthreads();
for (int s = blockSize / 2; s > 0; s >>= 1) {
if (tid < s) {
xx[tid] += xx[tid + s];
yy[tid] += yy[tid + s];
xy[tid] += xy[tid + s];
}
__syncthreads();
}
if (xy[0] == 0) {
real reciprocal = 1.0 / (sqrt(xx[0]) * sqrt(yy[0]));
for (int index = tid; index < width; index += blockSize) {
prevGradX[index] +=
scale * grad[ty] * prevOutY[index] * reciprocal;
if (input2_height > 1) {
prevGradY[index] +=
scale * grad[ty] * prevOutX[index] * reciprocal;
} else {
paddle::paddleAtomicAdd(prevGradY + index,
scale * grad[ty] * prevOutX[index] * reciprocal);
}
}
} else {
real reciprocalXY = 1.0 / xy[0];
real reciprocalSquareSumX = 1.0 / xx[0];
real reciprocalSquareSumY = 1.0 / yy[0];
for (int index = tid; index < width; index += blockSize) {
prevGradX[index] += output[ty] * grad[ty] *
(prevOutY[index] * reciprocalXY -
prevOutX[index] * reciprocalSquareSumX);
if (input2_height > 1) {
prevGradY[index] += output[ty] * grad[ty] *
(prevOutX[index] * reciprocalXY -
prevOutY[index] * reciprocalSquareSumY);
} else {
paddle::paddleAtomicAdd(prevGradY + index, output[ty] * grad[ty] *
(prevOutX[index] * reciprocalXY -
prevOutY[index] * reciprocalSquareSumY));
}
}
}
}
void hl_cossim_derivative(real* grad,
real* output,
real* prevOutX,
real* prevOutY,
real* prevGradX,
real* prevGradY,
int width,
int input1_height,
int input2_height,
real scale) {
CHECK_NOTNULL(grad);
CHECK_NOTNULL(output);
CHECK_NOTNULL(prevOutX);
CHECK_NOTNULL(prevOutY);
CHECK_NOTNULL(prevGradX);
CHECK_NOTNULL(prevGradY);
const int blockSize = 256;
dim3 threads(blockSize, 1);
dim3 grid(1, input1_height);
KeCosSimDerivative<blockSize><<<grid, threads, 0, STREAM_DEFAULT>>>
(grad, output, prevOutX, prevOutY, prevGradX, prevGradY, width,
input1_height, input2_height, scale);
CHECK_SYNC("hl_cossim_derivate failed");
}
__global__ void KeMatrixAddSharedBias(real* A,
real* B,
const int channel,
......
......@@ -190,7 +190,7 @@ public:
: BufferArg(VALUE_TYPE_INT32, shape, argType) {
bufferType_ = TENSOR_SEQUENCE_ID;
CHECK_EQ(shape_.ndims(), 1UL);
CHECK_GT(shape_[0], 1UL);
CHECK_GE(shape_[0], 1UL);
numSeqs_ = shape_[0] - 1;
}
......@@ -226,7 +226,8 @@ public:
SequenceArg(ValueType valueType,
const TensorShape& shape,
ArgType argType = UNSPECIFIED)
: BufferArg(valueType, shape, argType), startPositions_(TensorShape()) {
: BufferArg(valueType, shape, argType),
startPositions_(TensorShape({shape[0]})) {
bufferType_ = TENSOR_SEQUENCE_DATA;
}
......
......@@ -27,6 +27,7 @@ if(WITH_TESTING)
add_simple_unittest(ContextProjectionOpTest)
add_simple_unittest(PadOpTest)
add_simple_unittest(MulOpTest)
add_simple_unittest(CosSimOpTest)
endif()
endif()
......
......@@ -108,26 +108,23 @@ public:
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK(1 == inputs.size() || 2 == inputs.size());
CHECK_EQ((size_t)1, outputs.size());
CHECK(1UL == inputs.size() || 2UL == inputs.size());
CHECK_EQ(1UL, outputs.size());
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
<< "SequenceArg required here";
const auto val_seqs = dynamic_cast<const SequenceArg&>(inputs[0]);
auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]);
CHECK(out_seq.data() && val_seqs.data() && val_seqs.getSequenceId().data());
CHECK_EQ(out_seq.shape().ndims(), (size_t)2);
CHECK_EQ(val_seqs.shape().ndims(), (size_t)2);
CHECK_EQ(val_seqs.getSequenceId().shape().ndims(), (size_t)1);
if (2 == inputs.size()) {
CHECK_EQ(inputs[1].shape().ndims(), (size_t)2);
}
CHECK_EQ(out_seq.shape().ndims(), 2UL);
CHECK_EQ(val_seqs.shape().ndims(), 2UL);
/// dim of output = dim of input * context_length
CHECK_EQ(out_seq.shape()[1], val_seqs.shape()[1] * context_length_);
/// input and output has the same batch_size
CHECK_EQ(val_seqs.shape()[0], out_seq.shape()[0]);
/// dim of input == dim of weight
if (2 == inputs.size()) {
if (2UL == inputs.size()) {
CHECK_EQ(inputs[1].shape().ndims(), 2UL);
/// dim of input == dim of weight
CHECK_EQ(val_seqs.shape()[1], inputs[1].shape()[1]);
}
......@@ -135,10 +132,11 @@ public:
auto out_mat = out_seq.matrix<Device>();
const auto in_mat = val_seqs.matrix<Device>();
const auto w_mat =
(2 == inputs.size())
(2UL == inputs.size() && inputs[1].data())
? inputs[1].matrix<Device>()
: typename Tensor<real, Device>::Matrix(nullptr, 0, 0);
const auto seq_vec = val_seqs.getSequenceId().vector<int, Device>();
ContextProjectionForward<Device>(out_mat,
in_mat,
w_mat,
......@@ -235,36 +233,40 @@ public:
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)1, inputs.size());
CHECK_EQ((size_t)2, outputs.size());
CHECK_EQ(1UL, inputs.size());
CHECK(1UL == outputs.size() || 2UL == outputs.size());
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
<< "SequenceArg required here";
const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]);
CHECK(in_seq.data() && in_seq.getSequenceId().data());
CHECK_EQ(in_seq.shape().ndims(), (size_t)2);
CHECK_EQ(in_seq.getSequenceId().shape().ndims(), (size_t)1);
CHECK_EQ(out_seq.shape().ndims(), (size_t)2);
CHECK_EQ(out_seq.getSequenceId().shape().ndims(), (size_t)1);
CHECK_EQ(outputs[1].shape().ndims(), (size_t)2);
CHECK_EQ(in_seq.shape().ndims(), 2UL);
CHECK_EQ(out_seq.shape().ndims(), 2UL);
CHECK_EQ(out_seq.getSequenceId().shape().ndims(), 1UL);
/// dim of input grad == dim of weight
CHECK_EQ(out_seq.shape()[1], outputs[1].shape()[1]);
/// input and output grad has the same batch_size
CHECK_EQ(out_seq.shape()[0], in_seq.shape()[0]);
/// dim of output grad = dim of input grad * context_length
CHECK_EQ(in_seq.shape()[1], out_seq.shape()[1] * context_length_);
CHECK_EQ(out_seq.getArgType(), ADD_TO);
CHECK_EQ(outputs[1].getArgType(), ADD_TO);
if (2UL == outputs.size()) {
CHECK_EQ(outputs[1].shape().ndims(), 2UL);
/// dim of input grad == dim of weight
CHECK_EQ(out_seq.shape()[1], outputs[1].shape()[1]);
CHECK_EQ(outputs[1].getArgType(), ADD_TO);
}
const auto seq_vec = in_seq.getSequenceId().vector<int, Device>();
const auto out_grad_mat = in_seq.matrix<Device>();
auto in_grad_mat =
!out_seq.data() ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
: out_seq.matrix<Device>();
auto w_grad_mat = !outputs[1].data()
? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
: outputs[1].matrix<Device>();
auto w_grad_mat =
(2UL == outputs.size() && outputs[1].data())
? outputs[1].matrix<Device>()
: typename Tensor<real, Device>::Matrix(nullptr, 0, 0);
ContextProjectionBackward<Device>(out_grad_mat,
in_grad_mat,
w_grad_mat,
......@@ -304,17 +306,17 @@ public:
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(1, static_cast<int>(inputs.size()));
CHECK_EQ(1, static_cast<int>(outputs.size()));
CHECK_EQ(1UL, inputs.size());
CHECK_EQ(1UL, outputs.size());
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
<< "SequenceArg required here";
const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
const auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]);
CHECK(in_seq.data() && out_seq.data() && in_seq.getSequenceId().data());
CHECK_EQ(static_cast<int>(out_seq.shape().ndims()), 2);
CHECK_EQ(static_cast<int>(in_seq.shape().ndims()), 2);
CHECK_EQ(static_cast<int>(in_seq.getSequenceId().shape().ndims()), 1);
CHECK_EQ(out_seq.shape().ndims(), 2UL);
CHECK_EQ(in_seq.shape().ndims(), 2UL);
CHECK_EQ(in_seq.getSequenceId().shape().ndims(), 1UL);
/// output layer grad dim == input layer grad dim * context_length_
CHECK_EQ(in_seq.shape().ndims(), out_seq.shape().ndims() * context_length_);
/// input and output has the same batch_size
......@@ -355,14 +357,14 @@ public:
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(1, static_cast<int>(inputs.size()));
CHECK_EQ(1, static_cast<int>(outputs.size()));
CHECK_EQ(1UL, inputs.size());
CHECK_EQ(1UL, outputs.size());
CHECK(inputs[0].isSequenceArg()) << "SequenceArg required here";
const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
CHECK(in_seq.data() && in_seq.getSequenceId().data() && outputs[0].data());
CHECK_EQ(static_cast<int>(outputs[0].shape().ndims()), 2);
CHECK_EQ(static_cast<int>(in_seq.shape().ndims()), 2);
CHECK_EQ(static_cast<int>(in_seq.getSequenceId().shape().ndims()), 1);
CHECK_EQ(outputs[0].shape().ndims(), 2UL);
CHECK_EQ(in_seq.shape().ndims(), 2UL);
CHECK_EQ(in_seq.getSequenceId().shape().ndims(), 1UL);
CHECK_EQ(in_seq.shape()[0], outputs[0].shape()[0]);
/// output layer grad dim == weight dim * context_length_
CHECK_EQ(in_seq.shape()[1], outputs[0].shape()[1] * context_length_);
......
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
namespace paddle {
......
......@@ -28,55 +28,26 @@ void testMatrixProjectionForward(int context_start,
std::max(0, (int)(context_start + context_length - 1));
if (pad == 0) is_padding = false;
FunctionCompare compare("ContextProjectionForward",
FuncConfig()
.set("context_length", context_length)
.set("context_start", context_start)
.set("begin_pad", std::max(0, -context_start)));
CpuMatrix cpu_in(batch_size, input_dim);
cpu_in.randomizeUniform();
GpuMatrix gpu_in(batch_size, input_dim);
gpu_in.copyFrom(cpu_in);
auto cpu_weight =
is_padding ? std::make_shared<CpuMatrix>(pad, input_dim) : nullptr;
auto gpu_weight =
is_padding ? std::make_shared<GpuMatrix>(pad, input_dim) : nullptr;
if (is_padding) {
cpu_weight->randomizeUniform();
gpu_weight->copyFrom(*cpu_weight);
FunctionCompare test("ContextProjectionForward",
FuncConfig()
.set("context_length", context_length)
.set("context_start", context_start)
.set("begin_pad", std::max(0, -context_start)));
// prepare input arguments
test.addSequence(SequenceIdArg(TensorShape{batch_size}));
test.addInputs(
SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batch_size, input_dim}));
if (is_padding) { // weight
test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{pad, input_dim}));
}
IVectorPtr cpu_seq;
generateSequenceStartPositions(batch_size, cpu_seq);
IVectorPtr gpu_seq = IVector::create(cpu_seq->getSize(), true);
gpu_seq->copyFrom(*cpu_seq);
CpuMatrix cpu_out(batch_size, input_dim * context_length);
GpuMatrix gpu_out(batch_size, input_dim * context_length);
cpu_out.randomizeUniform();
gpu_out.copyFrom(cpu_out);
BufferArgs cpu_inputs;
BufferArgs cpu_outputs;
cpu_inputs.addArg(cpu_in, *cpu_seq);
if (cpu_weight) {
cpu_inputs.addArg(*cpu_weight, *cpu_seq);
}
cpu_outputs.addArg(cpu_out, *cpu_seq, ADD_TO);
compare.getCpuFunction()->calc(cpu_inputs, cpu_outputs);
test.addOutputs(
SequenceArg(VALUE_TYPE_FLOAT,
TensorShape{batch_size, input_dim * context_length}),
ADD_TO);
BufferArgs gpu_inputs;
BufferArgs gpu_outputs;
gpu_inputs.addArg(gpu_in, *gpu_seq);
if (gpu_weight) {
gpu_inputs.addArg(*gpu_weight, *gpu_seq);
}
gpu_outputs.addArg(gpu_out, *gpu_seq, ADD_TO);
compare.getGpuFunction()->calc(gpu_inputs, gpu_outputs);
autotest::TensorCheckEqual(cpu_out, gpu_out);
// run Function
test.run();
}
void testMatrixProjectionBackward(int context_start,
......@@ -88,63 +59,31 @@ void testMatrixProjectionBackward(int context_start,
std::max(0, (int)(context_start + context_length - 1));
if (pad == 0) is_padding = false;
FunctionCompare compare("ContextProjectionBackward",
FuncConfig()
.set("context_length", context_length)
.set("context_start", context_start)
.set("begin_pad", std::max(0, -context_start))
.set("is_padding", is_padding)
.set("total_pad", pad));
CpuMatrix cpu_in_grad(batch_size, input_dim);
cpu_in_grad.randomizeUniform();
GpuMatrix gpu_in_grad(batch_size, input_dim);
gpu_in_grad.copyFrom(cpu_in_grad);
CpuMatrix cpu_out_grad(batch_size, input_dim * context_length);
cpu_out_grad.randomizeUniform();
GpuMatrix gpu_out_grad(batch_size, input_dim * context_length);
gpu_out_grad.copyFrom(cpu_out_grad);
IVectorPtr cpu_seq;
generateSequenceStartPositions(batch_size, cpu_seq);
IVectorPtr gpu_seq = IVector::create(cpu_seq->getSize(), true);
gpu_seq->copyFrom(*cpu_seq);
auto cpu_w_grad =
is_padding ? std::make_shared<CpuMatrix>(pad, input_dim) : nullptr;
auto gpu_w_grad =
is_padding ? std::make_shared<GpuMatrix>(pad, input_dim) : nullptr;
if (is_padding) {
cpu_w_grad->randomizeUniform();
gpu_w_grad->copyFrom(*cpu_w_grad);
FunctionCompare test("ContextProjectionBackward",
FuncConfig()
.set("context_length", context_length)
.set("context_start", context_start)
.set("begin_pad", std::max(0, -context_start))
.set("is_padding", is_padding)
.set("total_pad", pad));
// prepare input arguments
test.addSequence(SequenceIdArg(TensorShape{batch_size}));
test.addInputs(SequenceArg(
VALUE_TYPE_FLOAT, TensorShape{batch_size, input_dim * context_length}));
test.addOutputs(
SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batch_size, input_dim}),
ADD_TO);
if (is_padding) { // weight
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{pad, input_dim}),
ADD_TO);
}
BufferArgs cpu_inputs;
BufferArgs cpu_outputs;
cpu_inputs.addArg(cpu_out_grad, *cpu_seq);
cpu_outputs.addArg(cpu_in_grad, *cpu_seq, ADD_TO);
cpu_outputs.addArg(
cpu_w_grad ? *cpu_w_grad : CpuMatrix(nullptr, 0, input_dim), ADD_TO);
compare.getCpuFunction()->calc(cpu_inputs, cpu_outputs);
BufferArgs gpu_inputs;
BufferArgs gpu_outputs;
gpu_inputs.addArg(gpu_out_grad, *gpu_seq);
gpu_outputs.addArg(gpu_in_grad, *gpu_seq, ADD_TO);
gpu_outputs.addArg(
gpu_w_grad ? *gpu_w_grad : GpuMatrix(nullptr, 0, input_dim), ADD_TO);
compare.getGpuFunction()->calc(gpu_inputs, gpu_outputs);
autotest::TensorCheckErr(cpu_in_grad, gpu_in_grad);
if (is_padding) {
autotest::TensorCheckErr(*cpu_w_grad, *gpu_w_grad);
}
// run Function
test.run();
}
TEST(ContextProjection, projection) {
TEST(ContextProjection, Projection) {
for (auto context_start : {-5, -3, -1, 0, 3}) {
for (auto context_length : {1, 2, 5, 7}) {
for (auto trainable_padding : {false, true}) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "CosSimOp.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/Vector.h"
namespace paddle {
/**
* Cosine Similarity for CpuMatrix
*
* \param out_mat, output value, size: nSamples * 1.
* \param in1_mat, input value 1, size: nSamples * dim.
* \param in2_mat, input value 2, size: n2 * dim (n2 == 1 or n2 == nSamples).
* \param scale, default 1.0
*
*/
template <>
void CosSimForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat,
const CpuMatrix& in1_mat,
const CpuMatrix& in2_mat,
real scale) {
CHECK(out_mat.getData() && in1_mat.getData() && in2_mat.getData());
size_t num_samples = out_mat.getHeight();
size_t dim = in1_mat.getWidth();
/// column vector [nSamples, 1]
real* out = out_mat.getData();
const real* x = in1_mat.getData();
const real* y = in2_mat.getData();
/// in2 might only have one row or full rows
CHECK(in2_mat.getHeight() == 1LU || in2_mat.getHeight() == num_samples);
size_t inc = (in2_mat.getHeight() == 1LU) ? 0 : dim;
for (size_t i = 0; i < num_samples; ++i, x += dim, y += inc) {
real square_sum_x = 0;
real square_sum_y = 0;
real xy = 0;
for (size_t j = 0; j < dim; ++j) {
square_sum_x += x[j] * x[j];
square_sum_y += y[j] * y[j];
xy += x[j] * y[j];
}
CHECK(square_sum_x > 0 && square_sum_y > 0);
out[i] = scale * xy / (std::sqrt(square_sum_x) * std::sqrt(square_sum_y));
}
}
/**
* Cosine Similarity
* for each row i,
* out[i] = scale * cos(input1[i], input2[i])
* = scale * <input1[i], input2[i]>/sqrt(|input1[i]|^2 * |input2[i]|^2)
* when input2 only has one row, then for each row i,
* out[i] = cos(input1[i], input2[0])
*
* \param inputs[0] input matrix 1, size: nSamples * dim.
* \param inputs[1] input matrix 2, size: n2 * dim (n2 == 1 or n2 == nSamples).
* \param outputs[0] output matrix, size : nSamples * 1.
*/
template <DeviceType Device>
class CosSimForwardFunc : public FunctionBase {
void init(const FuncConfig& config) override {
scale_ = config.get<real>("scale");
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(inputs.size(), 2UL);
CHECK_EQ(outputs.size(), 1UL);
CHECK_EQ(inputs[0].shape().ndims(), 2UL);
CHECK_EQ(inputs[1].shape().ndims(), 2UL);
CHECK_EQ(outputs[0].shape().ndims(), 2UL);
CHECK_EQ(inputs[0].shape()[0], outputs[0].shape()[0]);
CHECK_EQ(inputs[0].shape()[1], inputs[1].shape()[1]);
CHECK_EQ(outputs[0].shape()[1], 1UL);
CHECK(outputs[0].data() && inputs[0].data() && inputs[1].data());
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
auto out_mat = outputs[0].matrix<Device>();
const auto in1_mat = inputs[0].matrix<Device>();
const auto in2_mat = inputs[1].matrix<Device>();
CosSimForward<Device>(out_mat, in1_mat, in2_mat, scale_);
}
private:
real scale_;
};
/**
* Cosine Similarity Derivative for CpuMatrix
*
* \param in1_grad forward input grad 1, size: nSamples * dim.
* \param in2_grad forward input grad 2,
* size: n2 * dim (n2 == 1 or n2 == nSamples).
*
* \param out_grad backward loss output grad, size : nSamples * 1.
* \param out_val forward output value, size: nSamples * 1.
* \param in1_val forward input value 1, size: nSamples * dim.
* \param in2_val forward input value 2,
* size: n2 * dim (n2 == 1 or n2 == nSamples).
* \param scale, default 1.0
*/
template <>
void CosSimBackward<DEVICE_TYPE_CPU>(const CpuMatrix& out_grad,
const CpuMatrix& out_val,
const CpuMatrix& in1_val,
const CpuMatrix& in2_val,
CpuMatrix& in1_grad,
CpuMatrix& in2_grad,
real scale) {
CHECK(out_grad.getData() && out_val.getData() && in1_val.getData() &&
in2_val.getData() && in1_grad.getData() && in2_grad.getData());
CHECK_EQ(out_val.useGpu_, false) << "Matrix type are GPU, CPU required";
const real* grad = out_grad.getData();
const real* out = out_val.getData();
const real* prev_out_x = in1_val.getData();
const real* prev_out_y = in2_val.getData();
real* prev_grad_x = in1_grad.getData();
real* prev_grad_y = in2_grad.getData();
size_t num_samples = out_grad.getHeight();
size_t dim = in1_val.getWidth();
CHECK_EQ(in2_val.getHeight(), in2_grad.getHeight());
CHECK(in2_val.getHeight() == 1LU || in2_val.getHeight() == num_samples);
size_t inc = (in2_val.getHeight() == 1LU) ? 0 : dim;
for (size_t i = 0; i < num_samples; ++i,
prev_out_x += dim,
prev_out_y += inc,
prev_grad_x += dim,
prev_grad_y += inc) {
real square_sum_x = 0;
real square_sum_y = 0;
real xy = 0;
for (size_t j = 0; j < dim; ++j) {
square_sum_x += prev_out_x[j] * prev_out_x[j];
square_sum_y += prev_out_y[j] * prev_out_y[j];
xy += prev_out_x[j] * prev_out_y[j];
}
CHECK(square_sum_x > 0 && square_sum_y > 0);
if (xy == 0) {
real reciprocal =
1.0f / (std::sqrt(square_sum_x) * std::sqrt(square_sum_y));
for (size_t j = 0; j < dim; ++j) {
prev_grad_x[j] += scale * grad[i] * prev_out_y[j] * reciprocal;
prev_grad_y[j] += scale * grad[i] * prev_out_x[j] * reciprocal;
}
} else {
real reciprocal_xy = 1.0f / xy;
real reciprocal_square_sum_x = 1.0f / square_sum_x;
real reciprocal_square_sum_y = 1.0f / square_sum_y;
for (size_t j = 0; j < dim; ++j) {
prev_grad_x[j] +=
out[i] * grad[i] * (prev_out_y[j] * reciprocal_xy -
prev_out_x[j] * reciprocal_square_sum_x);
prev_grad_y[j] +=
out[i] * grad[i] * (prev_out_x[j] * reciprocal_xy -
prev_out_y[j] * reciprocal_square_sum_y);
}
}
}
}
/**
* Cosine Similarity backward Derivative
*
* \param outputs[0] forward input grad 1, size: nSamples * dim.
* \param outputs[1] forward input grad 2,
* size: n2 * dim (n2 == 1 or n2 == nSamples).
*
* \param inputs[0] backward loss output grad, size : nSamples * 1.
* \param inputs[1] forward output value, size: nSamples * 1.
* \param inputs[2] forward input value 1, size: nSamples * dim.
* \param inputs[3] forward input value 2,
* size: n2 * dim (n2 == 1 or n2 == nSamples).
*/
template <DeviceType Device>
class CosSimBackwardFunc : public FunctionBase {
void init(const FuncConfig& config) override {
scale_ = config.get<real>("scale");
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(inputs.size(), 4UL);
CHECK_EQ(outputs.size(), 2UL);
/// dim of out_grad and out_val == 1, column vector
CHECK_EQ(inputs[0].shape()[1], 1UL);
CHECK_EQ(inputs[1].shape()[1], 1UL);
/// nSamples of out_grad == out_val == in_val1 == in_grad1
CHECK_EQ(inputs[1].shape()[0], inputs[0].shape()[0]);
CHECK_EQ(inputs[0].shape()[0], inputs[0].shape()[0]);
CHECK_EQ(outputs[0].shape()[0], inputs[0].shape()[0]);
/// dim of in1_val1 == in_val2 == in_grad1 == in_grad2
CHECK_EQ(inputs[3].shape()[1], inputs[2].shape()[1]);
CHECK_EQ(outputs[0].shape()[1], inputs[2].shape()[1]);
CHECK_EQ(outputs[1].shape()[1], inputs[2].shape()[1]);
CHECK(inputs[0].data() && inputs[1].data() && inputs[2].data() &&
inputs[3].data() && outputs[0].data() && outputs[1].data());
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
CHECK_EQ(outputs[1].getArgType(), ADD_TO);
const auto out_grad = inputs[0].matrix<Device>();
const auto out_val = inputs[1].matrix<Device>();
const auto in1_val = inputs[2].matrix<Device>();
const auto in2_val = inputs[3].matrix<Device>();
auto in1_grad = outputs[0].matrix<Device>();
auto in2_grad = outputs[1].matrix<Device>();
CosSimBackward<Device>(
out_grad, out_val, in1_val, in2_val, in1_grad, in2_grad, scale_);
}
private:
real scale_;
};
REGISTER_TYPED_FUNC(CosSimForward, CPU, CosSimForwardFunc);
REGISTER_TYPED_FUNC(CosSimBackward, CPU, CosSimBackwardFunc);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(CosSimForward, GPU, CosSimForwardFunc);
REGISTER_TYPED_FUNC(CosSimBackward, GPU, CosSimBackwardFunc);
#endif
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
namespace paddle {
/**
* \brief Cosine Similarity Forward.
* for each row i,
* out[i] = scale * cos(in1[i], in2[i])
* = scale * \sum_j (in1[i][j] * in2[i][j]) /
* sqrt(sum_j (in1[i][j]^2) * sum_j (in2[i][j])^2)
*
* \param[out] output output value.
* \param[in] intput1 input value.
* \param[in] intput2 input value.
* \param[in] scale default 1.0.
*
*/
template <DeviceType Device>
void CosSimForward(typename Tensor<real, Device>::Matrix& output,
const typename Tensor<real, Device>::Matrix& input1,
const typename Tensor<real, Device>::Matrix& input2,
real scale);
/**
* \brief Cosine Similarity BackWard for Derivative.
*
* \param[in] output grad backward loss output grad.
* \param[in] output val forward-output value.
* \param[in] input val1 forward input value 1.
* \param[in] input val2 forward input value 2.
* \param[in/out] input grad forward input grad 1.
* \param[in/out] input grad forward input grad 2.
* \param[in] scale default 1.0.
*
*/
template <DeviceType Device>
void CosSimBackward(const typename Tensor<real, Device>::Matrix& out_grad,
const typename Tensor<real, Device>::Matrix& out_value,
const typename Tensor<real, Device>::Matrix& in1_value,
const typename Tensor<real, Device>::Matrix& in2_value,
typename Tensor<real, Device>::Matrix& in1_grad,
typename Tensor<real, Device>::Matrix& in2_grad,
real scale);
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "hl_device_functions.cuh"
#include "CosSimOp.h"
namespace paddle {
template<int block_size>
__global__ void KeCosSim(real* output,
const real* input1,
const real* input2,
int width,
int input1_height,
int input2_height,
real scale) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
__shared__ real xx[block_size];
__shared__ real yy[block_size];
__shared__ real xy[block_size];
xx[tid] = 0.0;
yy[tid] = 0.0;
xy[tid] = 0.0;
__syncthreads();
input1 += ty * width;
if (input2_height > 1) {
input2 += ty * width;
}
for (int index = tid; index < width; index += block_size) {
real x = input1[index];
real y = input2[index];
xx[tid] += x * x;
yy[tid] += y * y;
xy[tid] += x * y;
}
__syncthreads();
for (int s = block_size / 2; s > 0; s >>= 1) {
if (tid < s) {
xx[tid] += xx[tid + s];
yy[tid] += yy[tid + s];
xy[tid] += xy[tid + s];
}
__syncthreads();
}
if (tid == 0) {
output[ty] = scale * xy[0] / (sqrt(xx[0]) * sqrt(yy[0]));
}
}
void hlCossim(real* output,
const real* input1,
const real* input2,
size_t width,
size_t input1_height,
size_t input2_height,
real scale) {
CHECK_NOTNULL(output);
CHECK_NOTNULL(input1);
CHECK_NOTNULL(input2);
const int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid(1, input1_height);
KeCosSim<block_size><<<grid, threads, 0, STREAM_DEFAULT>>>
(output, input1, input2, width, input1_height, input2_height, scale);
CHECK_SYNC("hlCossim failed");
}
template <>
void CosSimForward<DEVICE_TYPE_GPU>(GpuMatrix& out_mat,
const GpuMatrix& in1_mat,
const GpuMatrix& in2_mat,
real scale) {
CHECK(out_mat.getData() && in1_mat.getData() && in2_mat.getData());
CHECK(in1_mat.useGpu_ == true && in2_mat.useGpu_ == true)
<< "Matrix type are not GPU";
size_t num_samples = out_mat.getHeight();
size_t dim = in1_mat.getWidth();
real* out = out_mat.getData();
const real* x = in1_mat.getData();
const real* y = in2_mat.getData();
hlCossim(out, x, y, dim, in1_mat.getHeight(), in2_mat.getHeight(), scale);
}
template<int block_size>
__global__ void KeCosSimDerivative(const real* grad,
const real* output,
const real* prev_out_x,
const real* prev_out_y,
real* prev_grad_x,
real* prev_grad_y,
size_t width,
size_t input1_height,
size_t input2_height,
real scale) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
__shared__ real xx[block_size];
__shared__ real yy[block_size];
__shared__ real xy[block_size];
xx[tid] = 0.0;
yy[tid] = 0.0;
xy[tid] = 0.0;
__syncthreads();
prev_out_x += ty * width;
prev_grad_x += ty * width;
if (input2_height > 1) {
prev_out_y += ty * width;
prev_grad_y += ty * width;
}
for (int index = tid; index < width; index += block_size) {
real x = prev_out_x[index];
real y = prev_out_y[index];
xx[tid] += x * x;
yy[tid] += y * y;
xy[tid] += x * y;
}
__syncthreads();
for (int s = block_size / 2; s > 0; s >>= 1) {
if (tid < s) {
xx[tid] += xx[tid + s];
yy[tid] += yy[tid + s];
xy[tid] += xy[tid + s];
}
__syncthreads();
}
if (xy[0] == 0) {
real reciprocal = 1.0 / (sqrt(xx[0]) * sqrt(yy[0]));
for (int index = tid; index < width; index += block_size) {
prev_grad_x[index] +=
scale * grad[ty] * prev_out_y[index] * reciprocal;
if (input2_height > 1) {
prev_grad_y[index] +=
scale * grad[ty] * prev_out_x[index] * reciprocal;
} else {
paddle::paddleAtomicAdd(prev_grad_y + index,
scale * grad[ty] * prev_out_x[index] * reciprocal);
}
}
} else {
real reciprocalXY = 1.0 / xy[0];
real reciprocalSquareSumX = 1.0 / xx[0];
real reciprocalSquareSumY = 1.0 / yy[0];
for (int index = tid; index < width; index += block_size) {
prev_grad_x[index] += output[ty] * grad[ty] *
(prev_out_y[index] * reciprocalXY -
prev_out_x[index] * reciprocalSquareSumX);
if (input2_height > 1) {
prev_grad_y[index] += output[ty] * grad[ty] *
(prev_out_x[index] * reciprocalXY -
prev_out_y[index] * reciprocalSquareSumY);
} else {
paddle::paddleAtomicAdd(prev_grad_y + index, output[ty] * grad[ty] *
(prev_out_x[index] * reciprocalXY -
prev_out_y[index] * reciprocalSquareSumY));
}
}
}
}
void hlCossimDerivative(const real* grad,
const real* output,
const real* prev_out_x,
const real* prev_out_y,
real* prev_grad_x,
real* prev_grad_y,
size_t width,
size_t input1_height,
size_t input2_height,
real scale) {
CHECK_NOTNULL(grad);
CHECK_NOTNULL(output);
CHECK_NOTNULL(prev_out_x);
CHECK_NOTNULL(prev_out_y);
CHECK_NOTNULL(prev_grad_x);
CHECK_NOTNULL(prev_grad_y);
const int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid(1, input1_height);
KeCosSimDerivative<block_size><<<grid, threads, 0, STREAM_DEFAULT>>>
(grad, output, prev_out_x, prev_out_y, prev_grad_x, prev_grad_y, width,
input1_height, input2_height, scale);
CHECK_SYNC("hlCossimDerivate failed");
}
template <>
void CosSimBackward<DEVICE_TYPE_GPU>(const GpuMatrix& out_grad,
const GpuMatrix& out_val,
const GpuMatrix& in1_val,
const GpuMatrix& in2_val,
GpuMatrix& in1_grad,
GpuMatrix& in2_grad,
real scale) {
CHECK(out_grad.getData() && out_val.getData() && in1_val.getData() &&
in2_val.getData() && in1_grad.getData() && in2_grad.getData());
CHECK(out_grad.useGpu_ && out_val.useGpu_ && in1_val.useGpu_
&& in2_val.useGpu_ && in1_grad.useGpu_ && in2_grad.useGpu_)
<< "Matrix types are not equally GPU";
size_t dim = in1_val.getWidth();
const real* grad = out_grad.getData();
const real* out = out_val.getData();
const real* prev_out_x = in1_val.getData();
const real* prev_out_y = in2_val.getData();
real* prev_grad_x = in1_grad.getData();
real* prev_grad_y = in2_grad.getData();
hlCossimDerivative(grad,
out,
prev_out_x,
prev_out_y,
prev_grad_x,
prev_grad_y,
dim,
in1_val.getHeight(),
in2_val.getHeight(),
scale);
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "FunctionTest.h"
#include "paddle/math/Matrix.h"
using namespace paddle; // NOLINT
void testCosSimForward(size_t height_x,
size_t height_y,
size_t width,
real scale) {
FunctionCompare test("CosSimForward", FuncConfig().set("scale", scale));
// prepare input arguments
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width}));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width}));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1}),
ASSIGN_TO);
// run Function
test.run();
}
void testCosSimBackward(size_t height_x,
size_t height_y,
size_t width,
real scale) {
FunctionCompare test("CosSimBackward", FuncConfig().set("scale", scale));
// prepare input arguments
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1}));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1}));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width}));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width}));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width}),
ADD_TO);
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width}),
ADD_TO);
// run Function
test.run();
}
TEST(Matrix, cosSim) {
for (auto height_x : {10, 100, 1000}) {
for (auto height_y : {1, height_x}) {
for (auto width : {10, 100, 1000}) {
for (auto scale : {1.0, 2.0}) {
testCosSimForward(height_x, height_y, width, scale);
testCosSimBackward(height_x, height_y, width, scale);
}
}
}
}
}
......@@ -69,6 +69,54 @@ public:
gpuMemory_.back()->getBuf(), input.valueType(), input.shape()));
}
// assume one copy of sequence is shared by different SequenceArgs
void addSequence(const SequenceIdArg& input) {
CHECK_EQ(input.shape().ndims(), 1UL);
size_t batchSize = input.shape()[0];
size_t numSeqs = batchSize / 10 + 1;
size_t sizeId = (numSeqs + 1) * sizeOfValuType(VALUE_TYPE_INT32);
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(sizeId));
gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(sizeId));
cpuSeq_ = std::make_shared<SequenceIdArg>(cpuMemory_.back()->getBuf(),
TensorShape{numSeqs + 1});
gpuSeq_ = std::make_shared<SequenceIdArg>(gpuMemory_.back()->getBuf(),
TensorShape{numSeqs + 1});
/// init sequence Id
initArg(*cpuSeq_, batchSize);
// todo(tianbing), delete it
CHECK_EQ(cpuSeq_->shape().getElements(), cpuSeq_->numSeqs() + 1);
CpuIVector cpuSeq(cpuSeq_->shape().getElements(), (int*)cpuSeq_->data());
GpuIVector gpuSeq(gpuSeq_->shape().getElements(), (int*)gpuSeq_->data());
gpuSeq.copyFrom(cpuSeq);
}
void addInputs(const SequenceArg& input) {
CHECK_EQ(input.shape().ndims(), 2UL);
size_t batchSize = input.shape()[0];
if (!cpuSeq_ || !gpuSeq_) { // sequence not exist
addSequence(SequenceIdArg(TensorShape{batchSize}));
}
size_t size =
input.shape().getElements() * sizeOfValuType(input.valueType());
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size));
gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(size));
/// SequenceArg
cpuInputs_.emplace_back(
std::make_shared<SequenceArg>(cpuMemory_.back()->getBuf(),
input.valueType(),
input.shape(),
*cpuSeq_));
gpuInputs_.emplace_back(
std::make_shared<SequenceArg>(gpuMemory_.back()->getBuf(),
input.valueType(),
input.shape(),
*gpuSeq_));
}
// output need only contains shape, do not contains data.
void addOutputs(const BufferArg& output, ArgType argType = ASSIGN_TO) {
size_t size =
......@@ -116,24 +164,31 @@ public:
std::make_shared<SparseMatrixArg>(*gpuSparse_, argType));
}
void addInputs(const SequenceArg& input) {
size_t batchSize = input.shape()[0];
size_t numSeqs = batchSize / 10 + 1;
size_t sizeId = (numSeqs + 1) * sizeOfValuType(VALUE_TYPE_INT32);
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(sizeId));
gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(sizeId));
TensorShape seqsId({numSeqs + 1});
// void* cpuBuffer = cpuMemory_.back()->getBuf();
// void* gpuBuffer = gpuMemory_.back()->getBuf();
void addOutputs(const SequenceArg& output, ArgType argType = ASSIGN_TO) {
CHECK_EQ(output.shape().ndims(), 2UL);
size_t batchSize = output.shape()[0];
if (!cpuSeq_ || !gpuSeq_) { // sequence not exist
addSequence(SequenceIdArg(TensorShape{batchSize}));
}
size_t size =
input.shape().getElements() * sizeOfValuType(input.valueType());
output.shape().getElements() * sizeOfValuType(output.valueType());
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size));
gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(size));
// TODO: need be implemented.
/// SequenceArg
cpuOutputs_.emplace_back(
std::make_shared<SequenceArg>(cpuMemory_.back()->getBuf(),
output.valueType(),
output.shape(),
*cpuSeq_,
argType));
gpuOutputs_.emplace_back(
std::make_shared<SequenceArg>(gpuMemory_.back()->getBuf(),
output.valueType(),
output.shape(),
*gpuSeq_,
argType));
}
void addInputs(const SparseMatrixArg& input) {
......@@ -193,14 +248,44 @@ public:
std::shared_ptr<FunctionBase> getGpuFunction() const { return gpuFunc_; }
protected:
// only init cpu argument, gpu argument copy from cpu argument.
void initArg(BufferArg& arg) {
CpuVector vector(arg.shape().getElements(), (real*)arg.data());
vector.uniform(0.001, 1);
}
void initArg(SequenceArg& arg) {
/// init only matrix
CpuVector vector(arg.shape().getElements(), (real*)arg.data());
vector.uniform(0.001, 1);
}
void initArg(SequenceIdArg& arg, size_t batchSize) {
size_t numSeqs = arg.numSeqs();
int* buf = reinterpret_cast<int*>(arg.data());
int pos = 0;
size_t maxLen = 2 * batchSize / numSeqs;
for (int i = 0; i < (int)numSeqs; ++i) {
int len = 1 + uniformRandom(std::min<int64_t>(
maxLen, batchSize - pos - numSeqs + i));
buf[i] = pos;
pos += len;
VLOG(1) << " len=" << len;
}
buf[numSeqs] = batchSize;
}
void initInputs() {
for (size_t i = 0; i < cpuInputs_.size(); i++) {
if (cpuInputs_[i]->isSparseArg()) {
continue; /// sparse matrix already init
}
initArg(*cpuInputs_[i]);
if (cpuInputs_[i]->isSequenceArg()) {
initArg(dynamic_cast<SequenceArg&>(*cpuInputs_[i]));
} else {
initArg(*cpuInputs_[i]);
}
// TODO: Need a BufferCopy used to copy from one BufferArg to another.
CpuVector cpuVector(cpuInputs_[i]->shape().getElements(),
(real*)cpuInputs_[i]->data());
......@@ -217,7 +302,11 @@ protected:
continue; /// sparse matrix already init
}
initArg(*cpuOutputs_[i]);
if (cpuOutputs_[i]->isSequenceArg()) {
initArg(dynamic_cast<SequenceArg&>(*cpuOutputs_[i]));
} else {
initArg(*cpuOutputs_[i]);
}
// TODO: Need a BufferCopy used to copy from one BufferArg to another.
CpuVector cpuVector(cpuOutputs_[i]->shape().getElements(),
......@@ -241,28 +330,6 @@ protected:
}
}
// only init cpu argument, gpu argument copy from cpu argument.
void initArg(BufferArg& arg) {
CpuVector vector(arg.shape().getElements(), (real*)arg.data());
vector.uniform(0.001, 1);
}
void initArg(SequenceIdArg& arg, size_t batchSize) {
size_t numSeqs = arg.numSeqs();
int* buf = reinterpret_cast<int*>(arg.data());
int pos = 0;
size_t maxLen = 2 * batchSize / numSeqs;
for (int i = 0; i < (int)numSeqs; ++i) {
int len = uniformRandom(
std::min<int64_t>(maxLen, batchSize - pos - numSeqs + i)) +
1;
buf[i] = pos;
pos += len;
VLOG(1) << " len=" << len;
}
buf[numSeqs] = batchSize;
}
protected:
std::shared_ptr<FunctionBase> cpuFunc_;
std::shared_ptr<FunctionBase> gpuFunc_;
......@@ -274,6 +341,8 @@ protected:
std::vector<BufferArgPtr> gpuOutputs_;
std::shared_ptr<CpuSparseMatrix> cpuSparse_;
std::shared_ptr<GpuSparseMatrix> gpuSparse_;
std::shared_ptr<SequenceIdArg> cpuSeq_;
std::shared_ptr<SequenceIdArg> gpuSeq_;
};
} // namespace paddle
......@@ -60,7 +60,7 @@ TEST(MulOp, DDDMatrixMul) {
if (transa && transb) {
continue;
}
VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ')
VLOG(3) << std::setiosflags(std::ios::left) << std::setfill(' ')
<< " transa=" << transa << " transb=" << transb
<< " dimM=" << std::setw(5) << dimM
<< " dimN=" << std::setw(5) << dimN
......@@ -104,7 +104,7 @@ TEST(MuLOp, DSparseDMul) {
for (const auto dimK : {3, 10}) {
for (const auto nnz : {3, 10}) {
for (const auto FORMAT : {SPARSE_CSR}) {
VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ')
VLOG(3) << std::setiosflags(std::ios::left) << std::setfill(' ')
<< " dimM=" << std::setw(5) << dimM
<< " dimN=" << std::setw(5) << dimN
<< " dimK=" << std::setw(5) << dimK
......@@ -150,7 +150,7 @@ TEST(MulOp, DDSparseMul) {
for (const auto dimK : {3, 10}) {
for (const auto nnz : {3, 10}) {
for (const auto FORMAT : {SPARSE_CSR, SPARSE_CSC}) {
VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ')
VLOG(3) << std::setiosflags(std::ios::left) << std::setfill(' ')
<< " dimM=" << std::setw(5) << dimM
<< " dimN=" << std::setw(5) << dimN
<< " dimK=" << std::setw(5) << dimK
......@@ -197,7 +197,7 @@ TEST(MulOp, SparseDDMul) {
for (const auto dimK : {3, 10}) {
for (const auto nnz : {3, 10}) {
for (const auto FORMAT : {SPARSE_CSC, SPARSE_CSR}) {
VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ')
VLOG(3) << std::setiosflags(std::ios::left) << std::setfill(' ')
<< " dimM=" << std::setw(5) << dimM
<< " dimN=" << std::setw(5) << dimN
<< " dimK=" << std::setw(5) << dimK
......
......@@ -647,7 +647,7 @@ public:
DataBatch& gpuBatch = *batch;
std::vector<Argument>& gpuArguments = gpuBatch.getStreams();
gpuArguments.resize(cpuArguments.size());
gpuBatch.setSize(size);
gpuBatch.setSize(bsize);
for (size_t i = 0; i < headers_.size(); ++i) {
gpuArguments[i].resizeAndCopyFrom(
cpuArguments[i], useGpu_, HPPL_STREAM_1);
......
......@@ -26,15 +26,23 @@ bool CosSimLayer::init(const LayerMap& layerMap,
Layer::init(layerMap, parameterMap);
CHECK_EQ(inputLayers_.size(), 2LU);
createFunction(forward_,
"CosSimForward",
FuncConfig().set("scale", (real)config_.cos_scale()));
createFunction(backward_,
"CosSimBackward",
FuncConfig().set("scale", (real)config_.cos_scale()));
return true;
}
void CosSimLayer::forward(PassType passType) {
Layer::forward(passType);
/* malloc memory for the output_ if necessary */
int batchSize = getInputValue(0)->getHeight();
int size = getSize();
CHECK_EQ(forward_.size(), 1) << "Only one forward function needed";
{
REGISTER_TIMER_INFO("CosFwResetTimer", getName().c_str());
......@@ -42,26 +50,43 @@ void CosSimLayer::forward(PassType passType) {
}
MatrixPtr outV = getOutputValue();
/* activation */ {
REGISTER_TIMER_INFO("CosFwAtvTimer", getName().c_str());
MatrixPtr prevOut1 = getInputValue(0);
MatrixPtr prevOut2 = getInputValue(1);
outV->cosSim(*prevOut1, *prevOut2, config_.cos_scale());
CHECK(outV && prevOut1 && prevOut2);
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*prevOut1);
inputs.addArg(*prevOut2);
outputs.addArg(*outV, ASSIGN_TO);
forward_[0]->calc(inputs, outputs);
}
}
void CosSimLayer::backward(const UpdateCallback& callback) {
/* activation */ {
REGISTER_TIMER_INFO("CosBpAtvTimer", getName().c_str());
MatrixPtr outG = this->getOutputGrad();
outG->cosSimDerivative(*this->getOutputValue(),
*getInputValue(0),
*getInputValue(1),
*getInputGrad(0),
*getInputGrad(1),
config_.cos_scale());
CHECK_EQ(backward_.size(), 1) << "Only one backward function needed";
const auto outG = this->getOutputGrad();
const auto outV = this->getOutputValue();
const auto inV1 = this->getInputValue(0);
const auto inV2 = this->getInputValue(1);
auto inG1 = this->getInputGrad(0);
auto inG2 = this->getInputGrad(1);
CHECK(outG && outV && inV1 && inV2 && inG1 && inG2);
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*outG);
inputs.addArg(*outV);
inputs.addArg(*inV1);
inputs.addArg(*inV2);
outputs.addArg(*inG1, ADD_TO);
outputs.addArg(*inG2, ADD_TO);
backward_[0]->calc(inputs, outputs);
}
}
......
......@@ -28,7 +28,7 @@ namespace paddle {
*
* - Input1: A vector (batchSize * dataDim) *
* - Input2: A vector (batchSize * dataDim) or (1 * dataDim) *
* - Output: A vector (dataDim * 1)
* - Output: A vector (batchSize * 1)
*
* The config file api is cos_sim.
*/
......
......@@ -18,7 +18,6 @@ limitations under the License. */
#include "paddle/utils/Stat.h"
namespace paddle {
/**
* @brief A layer for computing cosine similarity between a vector
* and each row of a matrix
......@@ -98,11 +97,22 @@ bool CosSimVecMatLayer::init(const LayerMap& layerMap,
dataDim,
/* trans= */ false,
useGpu_);
CHECK(tmpRow0 && tmpRow1 && tmpRow2 && tmpRow3 && tmpMtx0 && tmpMtx1);
createFunction(forward_,
"CosSimForward",
FuncConfig().set("scale", (real)config_.cos_scale()));
createFunction(backward_,
"CosSimBackward",
FuncConfig().set("scale", (real)config_.cos_scale()));
return true;
}
void CosSimVecMatLayer::forward(PassType passType) {
Layer::forward(passType);
CHECK_EQ(forward_.size(), 1) << "Only one forward function needed";
MatrixPtr inV0 = getInputValue(0);
MatrixPtr inV1 = getInputValue(1);
......@@ -118,17 +128,25 @@ void CosSimVecMatLayer::forward(PassType passType) {
}
MatrixPtr outV = getOutputValue();
CHECK(outV && inV0 && inV1);
REGISTER_TIMER_INFO("FwCosVMTimer", getName().c_str());
for (size_t i = 0; i < batchSize; i++) {
tmpRow0->setData(inV0->rowBuf(i));
tmpMtx0->setData(inV1->rowBuf(i));
tmpRow2->setData(outV->rowBuf(i));
tmpRow2->cosSim(*(tmpMtx0), *(tmpRow0), config_.cos_scale());
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*tmpMtx0);
inputs.addArg(*tmpRow0);
outputs.addArg(*tmpRow2, ASSIGN_TO);
forward_[0]->calc(inputs, outputs);
}
}
void CosSimVecMatLayer::backward(const UpdateCallback& callback) {
CHECK_EQ(backward_.size(), 1) << "Only one forward function needed";
MatrixPtr inV0 = getInputValue(0);
MatrixPtr inV1 = getInputValue(1);
MatrixPtr inG0 = getInputGrad(0);
......@@ -137,27 +155,27 @@ void CosSimVecMatLayer::backward(const UpdateCallback& callback) {
MatrixPtr outG = getOutputGrad();
size_t batchSize = inV0->getHeight();
CHECK(inV0 && inV1 && inG0 && inG1 && outV && outG);
REGISTER_TIMER_INFO("BwCosVMTimer", getName().c_str());
if (inG0 && inG1) {
for (size_t i = 0; i < batchSize; i++) {
tmpRow0->setData(inV0->rowBuf(i));
tmpRow1->setData(inG0->rowBuf(i));
tmpMtx0->setData(inV1->rowBuf(i));
tmpMtx1->setData(inG1->rowBuf(i));
tmpRow2->setData(outV->rowBuf(i));
tmpRow3->setData(outG->rowBuf(i));
tmpRow3->cosSimDerivative(*(tmpRow2),
*(tmpMtx0),
*(tmpRow0),
*(tmpMtx1),
*(tmpRow1),
config_.cos_scale());
}
} else {
CHECK(!inG0 || !inG1) << "Not supported";
for (size_t i = 0; i < batchSize; i++) {
tmpRow0->setData(inV0->rowBuf(i));
tmpRow1->setData(inG0->rowBuf(i));
tmpMtx0->setData(inV1->rowBuf(i));
tmpMtx1->setData(inG1->rowBuf(i));
tmpRow2->setData(outV->rowBuf(i));
tmpRow3->setData(outG->rowBuf(i));
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*tmpRow3);
inputs.addArg(*tmpRow2);
inputs.addArg(*tmpMtx0);
inputs.addArg(*tmpRow0);
outputs.addArg(*tmpMtx1, ADD_TO);
outputs.addArg(*tmpRow1, ADD_TO);
backward_[0]->calc(inputs, outputs);
}
}
......
......@@ -941,59 +941,6 @@ void GpuMatrix::softreluDerivative(Matrix& output) {
void GpuMatrix::scaledTanh(Matrix& output, real p1, real p2) {
BaseMatrix::scaledTanh(output, p1, p2);
}
void GpuMatrix::cosSim(Matrix& output1, Matrix& output2, real scale) {
CHECK(output1.useGpu_ == true && output2.useGpu_ == true)
<< "Matrix type are not equal";
size_t numSamples = getHeight();
size_t dim = output1.getWidth();
CHECK_EQ(getWidth(), 1UL);
CHECK_EQ(output1.getHeight(), numSamples);
CHECK_EQ(output1.getWidth(), output2.getWidth());
real* out = getData();
real* x = output1.getData();
real* y = output2.getData();
hl_cossim(out, x, y, dim, output1.getHeight(), output2.getHeight(), scale);
}
void GpuMatrix::cosSimDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2,
real scale) {
CHECK(output.useGpu_ == true && prevOut1.useGpu_ == true &&
prevOut2.useGpu_ == true && prevGrad1.useGpu_ == true &&
prevGrad2.useGpu_ == true)
<< "Matrix type are not equal";
CHECK_EQ(getWidth(), 1UL);
CHECK_EQ(output.getWidth(), 1UL);
size_t numSamples = getHeight();
CHECK_EQ(output.getHeight(), numSamples);
CHECK_EQ(prevOut1.getHeight(), numSamples);
CHECK_EQ(prevGrad1.getHeight(), numSamples);
size_t dim = prevOut1.getWidth();
CHECK_EQ(prevOut2.getWidth(), dim);
CHECK_EQ(prevGrad1.getWidth(), dim);
CHECK_EQ(prevGrad2.getWidth(), dim);
real* grad = getData();
real* out = output.getData();
real* prevOutX = prevOut1.getData();
real* prevOutY = prevOut2.getData();
real* prevGradX = prevGrad1.getData();
real* prevGradY = prevGrad2.getData();
hl_cossim_derivative(grad,
out,
prevOutX,
prevOutY,
prevGradX,
prevGradY,
dim,
prevOut1.getHeight(),
prevOut2.getHeight(),
scale);
}
void GpuMatrix::randomizeUniform() {
CHECK(isContiguous());
......@@ -3470,105 +3417,6 @@ void CpuMatrix::softmaxDerivative(Matrix& output, Matrix& sftmaxSum) {
}
}
void CpuMatrix::cosSim(Matrix& output1, Matrix& output2, real scale) {
size_t numSamples = getHeight();
size_t dim = output1.getWidth();
CHECK_EQ(getWidth(), 1UL);
CHECK_EQ(output1.getHeight(), numSamples);
CHECK_EQ(output1.getWidth(), output2.getWidth());
real* out = getData();
const real* x = output1.getData();
const real* y = output2.getData();
size_t yInc = dim;
if (output2.getHeight() == 1LU) {
yInc = 0;
} else {
CHECK_EQ(output2.getHeight(), numSamples);
}
for (size_t i = 0; i < numSamples; ++i, x += dim, y += yInc) {
real squareSumX = 0;
real squareSumY = 0;
real xy = 0;
for (size_t j = 0; j < dim; ++j) {
squareSumX += _square(x[j]);
squareSumY += _square(y[j]);
xy += x[j] * y[j];
}
CHECK(squareSumX > 0 && squareSumY > 0);
out[i] = scale * xy / (std::sqrt(squareSumX) * std::sqrt(squareSumY));
}
}
void CpuMatrix::cosSimDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2,
real scale) {
CHECK(output.useGpu_ == false) << "Matrix type are not equal";
CHECK_EQ(getWidth(), 1UL);
CHECK_EQ(output.getWidth(), 1UL);
size_t numSamples = getHeight();
CHECK_EQ(output.getHeight(), numSamples);
CHECK_EQ(prevOut1.getHeight(), numSamples);
CHECK_EQ(prevGrad1.getHeight(), numSamples);
size_t dim = prevOut1.getWidth();
CHECK_EQ(prevOut2.getWidth(), dim);
CHECK_EQ(prevGrad1.getWidth(), dim);
CHECK_EQ(prevGrad2.getWidth(), dim);
const real* grad = getData();
const real* out = output.getData();
const real* prevOutX = prevOut1.getData();
const real* prevOutY = prevOut2.getData();
real* prevGradX = prevGrad1.getData();
real* prevGradY = prevGrad2.getData();
size_t yInc = dim;
if (prevOut2.getHeight() == 1LU) {
yInc = 0;
CHECK_EQ(prevGrad2.getHeight(), 1LU);
} else {
CHECK_EQ(prevOut2.getHeight(), numSamples);
CHECK_EQ(prevGrad2.getHeight(), numSamples);
}
for (size_t i = 0; i < numSamples; ++i,
prevOutX += dim,
prevOutY += yInc,
prevGradX += dim,
prevGradY += yInc) {
real squareSumX = 0;
real squareSumY = 0;
real xy = 0;
for (size_t j = 0; j < dim; ++j) {
squareSumX += _square(prevOutX[j]);
squareSumY += _square(prevOutY[j]);
xy += prevOutX[j] * prevOutY[j];
}
CHECK(squareSumX > 0 && squareSumY > 0);
if (xy == 0) {
real reciprocal = 1.0f / (std::sqrt(squareSumX) * std::sqrt(squareSumY));
for (size_t j = 0; j < dim; ++j) {
prevGradX[j] += scale * grad[i] * prevOutY[j] * reciprocal;
prevGradY[j] += scale * grad[i] * prevOutX[j] * reciprocal;
}
} else {
real reciprocalXY = 1.0f / xy;
real reciprocalSquareSumX = 1.0f / squareSumX;
real reciprocalSquareSumY = 1.0f / squareSumY;
for (size_t j = 0; j < dim; ++j) {
prevGradX[j] += out[i] * grad[i] * (prevOutY[j] * reciprocalXY -
prevOutX[j] * reciprocalSquareSumX);
prevGradY[j] += out[i] * grad[i] * (prevOutX[j] * reciprocalXY -
prevOutY[j] * reciprocalSquareSumY);
}
}
}
}
void CpuMatrix::sumOfSquares(Matrix& output, Matrix& label) {
CHECK(output.useGpu_ == false && label.useGpu_ == false)
<< "Matrix type are not equal";
......
......@@ -799,26 +799,6 @@ public:
LOG(FATAL) << "Not implemented";
}
/**
* cosine similarity, for each row i,
* this[i] = cos(output1[i], output2[i])
*
* output2 can only have one row, then for each row i,
* this[i] = cos(output1[i], output2[0])
*/
virtual void cosSim(Matrix& output1, Matrix& output2, real scale = 1.0f) {
LOG(FATAL) << "Not implemented";
}
virtual void cosSimDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2,
real scale = 1.0f) {
LOG(FATAL) << "Not implemented";
}
/// print out the values of elements to os
virtual void print(std::ostream& os) const {
LOG(FATAL) << "Not implemented";
......@@ -1324,14 +1304,6 @@ public:
void softreluDerivative(Matrix& output);
void scaledTanh(Matrix& output, real p1, real p2);
void cosSim(Matrix& output1, Matrix& output2, real scale);
void cosSimDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2,
real scale);
virtual void print(std::ostream& os) const;
virtual void print(std::ostream& os, size_t height, size_t width) const;
......@@ -1752,14 +1724,6 @@ public:
void softreluDerivative(Matrix& output);
void scaledTanh(Matrix& output, real p1, real p2);
void cosSim(Matrix& output1, Matrix& output2, real scale);
void cosSimDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2,
real scale);
void print(std::ostream& os) const;
void print(std::ostream& os, size_t height, size_t width) const;
void printOneRow(std::ostream& os, size_t idx) const;
......
......@@ -181,28 +181,6 @@ TEST(Matrix, copyByRowIndex) {
}
}
void testCosSim(int heightX, int heightY, int width, real scale) {
AutoCompare test(heightX, 1);
CpuMatrix arg1(heightX, width);
CpuMatrix arg2(heightY, width);
arg1.randomizeUniform();
arg2.randomizeUniform();
arg2.add(-0.5);
test.cmpWithArg(&Matrix::cosSim, arg1, arg2, scale);
}
TEST(Matrix, cosSim) {
for (auto heightX : {10, 100, 1000}) {
for (auto heightY : {1, heightX}) {
for (auto width : {10, 100, 1000}) {
for (auto scale : {1.0, 2.0}) {
testCosSim(heightX, heightY, width, scale);
}
}
}
}
}
void testParamReluForward(int height, int width, int w_height, int w_width) {
AutoCompare test(height, width);
CpuMatrix arg1(height, width);
......
......@@ -720,61 +720,6 @@ TEST(Matrix, sequenceAvgForward) {
}
}
void testCosSimDerivate(int heightX, int heightY, int width, real scale) {
MatrixPtr prevOutX = CpuMatrix::create(heightX, width, false, false);
MatrixPtr prevOutY = CpuMatrix::create(heightY, width, false, false);
MatrixPtr grad = CpuMatrix::create(heightX, 1, false, false);
MatrixPtr output = CpuMatrix::create(heightX, 1, false, false);
MatrixPtr prevGradX = CpuMatrix::create(heightX, width, false, false);
MatrixPtr prevGradY = CpuMatrix::create(heightY, width, false, false);
prevOutX->randomizeUniform();
prevOutY->randomizeUniform();
grad->randomizeUniform();
output->randomizeUniform();
prevGradX->randomizeUniform();
prevGradY->randomizeUniform();
MatrixPtr prevOutXGpu = GpuMatrix::create(heightX, width, false, true);
MatrixPtr prevOutYGpu = GpuMatrix::create(heightY, width, false, true);
MatrixPtr gradGpu = GpuMatrix::create(heightX, 1, false, true);
MatrixPtr outputGpu = GpuMatrix::create(heightX, 1, false, true);
MatrixPtr prevGradXGpu = GpuMatrix::create(heightX, width, false, true);
MatrixPtr prevGradYGpu = GpuMatrix::create(heightY, width, false, true);
prevOutXGpu->copyFrom(*prevOutX);
prevOutYGpu->copyFrom(*prevOutY);
gradGpu->copyFrom(*grad);
outputGpu->copyFrom(*output);
prevGradXGpu->copyFrom(*prevGradX);
prevGradYGpu->copyFrom(*prevGradY);
grad->cosSimDerivative(
*output, *prevOutX, *prevOutY, *prevGradX, *prevGradY, scale);
gradGpu->cosSimDerivative(*outputGpu,
*prevOutXGpu,
*prevOutYGpu,
*prevGradXGpu,
*prevGradYGpu,
scale);
TensorCheckErr(*prevGradX, *prevGradXGpu);
TensorCheckErr(*prevGradY, *prevGradYGpu);
}
TEST(Matrix, cosSimDerivate) {
for (auto heightX : {1, 10, 100}) {
for (auto heightY : {1, heightX}) {
for (auto width : {1, 10, 100}) {
for (auto scale : {1.0, 2.0}) {
testCosSimDerivate(heightX, heightY, width, scale);
}
}
}
}
}
void testParamReluBackwardDiff(int height,
int width,
int w_height,
......
......@@ -289,6 +289,7 @@ void mkDir(const char* filename) {
void mkDirRecursively(const char* dir) {
struct stat sb;
if (*dir == 0) return; // empty string
if (!stat(dir, &sb)) return;
mkDirRecursively(path::dirname(dir).c_str());
......
......@@ -201,7 +201,7 @@ def define_py_data_sources2(train_list, test_list, module, obj, args=None):
data.load_data_module = load_data_module
data.load_data_object = load_data_object
data.load_data_args = load_data_args
data.async_load_data = True
data.async_load_data = False
return data
define_py_data_sources(
......
......@@ -3677,26 +3677,27 @@ def pad_layer(input,
For example,
.. code-block::
input(2,2,2,3) = [
[ [[1,2,3], [3,4,5]],
[[2,3,5], [1,6,7]] ],
[ [[4,3,1], [1,8,7]],
[[3,8,9], [2,3,5]] ]
]
pad_c=[1,1], pad_h=[0,0], pad_w=[0,0]
output(2,4,2,3) = [
[ [[0,0,0], [0,0,0]],
[[1,2,3], [3,4,5]],
[[2,3,5], [1,6,7]],
[[0,0,0], [0,0,0]] ],
[ [[0,0,0], [0,0,0]],
[[4,3,1], [1,8,7]],
[[3,8,9], [2,3,5]],
[[0,0,0], [0,0,0]] ]
]
.. code-block:: python
input(2,2,2,3) = [
[ [[1,2,3], [3,4,5]],
[[2,3,5], [1,6,7]] ],
[ [[4,3,1], [1,8,7]],
[[3,8,9], [2,3,5]] ]
]
pad_c=[1,1], pad_h=[0,0], pad_w=[0,0]
output(2,4,2,3) = [
[ [[0,0,0], [0,0,0]],
[[1,2,3], [3,4,5]],
[[2,3,5], [1,6,7]],
[[0,0,0], [0,0,0]] ],
[ [[0,0,0], [0,0,0]],
[[4,3,1], [1,8,7]],
[[3,8,9], [2,3,5]],
[[0,0,0], [0,0,0]] ]
]
The simply usage is:
......@@ -4191,13 +4192,7 @@ def block_expand_layer(input,
@wrap_name_default()
@layer_support()
def maxout_layer(input,
groups,
num_channels=None,
size_x=None,
size_y=None,
name=None,
layer_attr=None):
def maxout_layer(input, groups, num_channels=None, name=None, layer_attr=None):
"""
A layer to do max out on conv layer output.
- Input: output of a conv layer.
......@@ -4227,12 +4222,6 @@ def maxout_layer(input,
:type num_channels: int|None
:param groups: The group number of input layer.
:type groups: int
:param size_x: conv output width. If None will be set
automatically from previous output.
:type size_x: int|None
:param size_y: conv output height. If None will be set
automatically from previous output.
:type size_y: int|None
:param name: The name of this layer, which can not specify.
:type name: None|basestring.
:param layer_attr: Extra Layer attribute.
......
......@@ -19,7 +19,7 @@ model_config {
data_config {
type: "py2"
files: "train.list"
async_load_data: true
async_load_data: false
for_test: false
load_data_module: "a"
load_data_object: "c"
......@@ -58,7 +58,7 @@ opt_config {
test_data_config {
type: "py2"
files: "test.list"
async_load_data: true
async_load_data: false
for_test: true
load_data_module: "b"
load_data_object: "d"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册