diff --git a/.travis.yml b/.travis.yml
index 162bebba091d84b295f929527de9804e65df5a65..5d82d9729b75ef493a0bd03921c453f9a519c8cd 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -54,7 +54,9 @@ before_install:
     fi
   - if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then paddle/scripts/travis/before_install.osx.sh; fi
   - if [[ "$JOB" == "PRE_COMMIT" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi
-  - pip install numpy wheel protobuf sphinx recommonmark sphinx_rtd_theme virtualenv pre-commit requests==2.9.2 LinkChecker
+  # Paddle currently uses protobuf 3.1. Protobuf 3.2 breaks compatibility, so we pin the Python
+  # protobuf version.
+  - pip install numpy wheel 'protobuf==3.1' sphinx recommonmark sphinx_rtd_theme virtualenv pre-commit requests==2.9.2 LinkChecker
 script:
   - paddle/scripts/travis/main.sh
 notifications:
diff --git a/doc/design/api.md b/doc/design/api.md
new file mode 100644
index 0000000000000000000000000000000000000000..8185d2af0ea264a2e7b4e28b9ed05279e4a22014
--- /dev/null
+++ b/doc/design/api.md
@@ -0,0 +1,262 @@
+# PaddlePaddle Design Doc
+
+## Ingredients
+
+Our design principle is to start from the essence: how do we allow
+users to express and solve their problems with neural networks?
+Some essential concepts that our API has to provide include:
+
+1. A *topology* is an expression of *layers*.
+
+1. A layer could be any kind of computation, including *cost*.
+
+1. Some layers have parameters, some don't. Most costs don't have
+   parameters.
+
+1. In some topologies, layers share parameters. For
+   example,
+   [the network for training a ranking model](https://github.com/PaddlePaddle/Paddle/issues/1311#issuecomment-279121850).
+
+1. At programming time, users specify topologies and the possible sharing
+   of parameters. PaddlePaddle can figure out and create the parameters
+   required (and possibly shared) by one or more topologies.
+
+
+## Starting from Examples
+
+As a summary
+of
+[our discussion](https://github.com/PaddlePaddle/Paddle/issues/1315),
+let us present two examples here:
+
+
+### Example 1. Sharing Parameters between Layers
+
+We use
+the
+[3-branch ranking](https://github.com/PaddlePaddle/Paddle/issues/1311#issuecomment-279121850) model
+in this example. For your convenience, I copy-and-paste the model's
+topology as follows:
+
+```
+A -> f -\
+Q -> f --> cost
+B -> f -/
+```
+
+The following program trains the topology, including the cost, and then
+uses the sub-network in the trained topology for inference:
+
+```python
+def f(input):
+    e = paddle.layer.embedding(input, parameter_name="embedding")
+    o = paddle.layer.softmax(e, parameter_name="semantic")
+    return o
+
+# Create 3 topologies (subnets); they share parameters because all
+# corresponding layers have the same parameter names.
+fA = f(paddle.layer.data(input_name="A"))
+fB = f(paddle.layer.data(input_name="B"))
+fQ = f(paddle.layer.data(input_name="Q"))
+
+topology = paddle.layer.less_than(
+    paddle.layer.cross_entropy(fA, fQ),
+    paddle.layer.cross_entropy(fB, fQ))
+
+# Derive the parameters required by the topology and create them in the model.
+parameters = paddle.parameters.create(topology)
+
+# Estimate parameters used in topology from data.
+paddle.train(topology, parameters, reader=read_ranking_model_data)
+
+# Inference using fA (or fB or fQ, as they share their parameters).
+[testA, testB, testQ] = read_ranking_model_data()
+print "The semantic-vector of testA: ", paddle.infer(fA, parameters, testA)
+```
+
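+Parameter sharing in this example hinges entirely on parameter names.
+Before moving on to the second example, here is a minimal sketch (not
+part of the proposed API) of how `paddle.parameters.create` could
+derive the shared parameter set from one or more topologies;
+`traverse_layers`, `Parameter`, and the layer attributes used below
+are hypothetical names, used only for illustration:
+
+```python
+def create(*topologies):
+    # In essence a ParameterSet: a map from parameter names to parameters.
+    params = {}
+    for topology in topologies:
+        for layer in traverse_layers(topology):  # hypothetical traversal helper
+            name = layer.parameter_name          # hypothetical layer attribute
+            if name is not None and name not in params:
+                params[name] = Parameter(name, layer.parameter_shape)
+    return params
+```
+
+Because `fA`, `fB`, and `fQ` name the same parameters, such a set would
+contain "embedding" and "semantic" exactly once each.
+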
+### Example 2. Sharing Parameters between "Models"
+
+We use [GAN](https://github.com/PaddlePaddle/book/tree/develop/gan) in
+this example. In the following example program, `d0` and `d1`
+correspond to the two networks in the following figure:
+
+
+```python
+def G(input):
+    # over-simplified example as G has only one layer:
+    return paddle.layer.fc(input, parameter_name="G")
+
+def D(input):
+    # again, over-simplified:
+    return paddle.layer.fc(input, parameter_name="D")
+
+# Construct the first topology, which contains both D and G.
+# By learning this topology, we update the parameters of G.
+d0 = paddle.layer.should_be_false(D(G(paddle.layer.data())))
+
+# Construct a second topology d1, which contains only D. By
+# training this topology, we update the parameters of D. Note
+# that d1 shares parameters with d0.
+d1 = paddle.layer.should_be_true(D(paddle.layer.data()))
+
+# Create parameters from a list of multiple topologies (models) for
+# the chance to share parameters between these topologies.
+parameters = paddle.parameters.create([d0, d1])
+
+# Iterative training of GAN.
+for ...:
+    train(d0, parameters, reader=read_from_rng, immutable_parameters={"D"})
+    train(d1, parameters, reader=read_from_realistic_images)
+
+# Use d1 for inference:
+print "D thinks a batch of images are realistic ", infer(d1, parameters, read_mnist_images)
+```
+
+
+### Summary
+
+
+The two programs above reveal some important design concerns:
+
+1. Users describe a topology as an expression of layers. Every layer
+   has a *parameter name*. If the user doesn't specify it explicitly, it's
+   automatically generated as a unique name. By
+   specifying parameter names, users can specify the sharing of
+   parameters between layers and even between topologies.
+
+1. `paddle.parameters.create` figures out the parameters required by one
+   or more topologies from the parameter names of their layers. It creates
+   these parameters and returns a `ParameterSet` object, which is in essence
+   a map from *parameter names* to *parameters*.
+
+1. At training and inference time, `paddle.train` and `paddle.infer`
+   require both a topology and the parameter set that holds the parameters
+   of that topology. There are several reasons for this:
+
+   1. It prevents users from forgetting to call
+      `paddle.parameters.create`.
+   1. `paddle.train` needs to know which parameter set to update.
+   1. Users could load another (pre-trained) parameter set and use it
+      with a topology in `paddle.infer`.
+
+1. By specifying the `immutable_parameters` parameter of
+   `paddle.train`, we can forbid the update of these parameters.
+
+
+## Reader
+
+Not all programming frameworks allow users to define I/O functions.
+An example is Google MapReduce, which can only read from text,
+SSTable, and RecordIO files. Hadoop MapReduce allows users to define
+readers and writers by deriving from the base classes `Reader` and
+`Writer`. The former is less flexible but also less error-prone. We
+decided to give users the flexibility to define their own readers.
+
+
+There are some open questions here, and one possible shape of the
+reader API is sketched after this list:
+
+1. **Should a reader return a Python dictionary?**
+
+1. **How to map multiple outputs from a reader to multiple data layers?**
+
+1. **How to easily compose some existing readers to read more data and
+   feed a topology with more data layers?**
+
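+As a starting point for the discussion, here is a minimal sketch that
+answers these questions one way: a reader is a plain Python generator
+that yields dictionaries keyed by data-layer name, and composition is
+an ordinary function over generators. All names below are hypothetical
+and for illustration only:
+
+```python
+def read_images():
+    # Stand-in for real data loading; yields one dictionary per instance.
+    for image in ["img0", "img1", "img2"]:
+        yield {"image": image}
+
+def read_labels():
+    for label in [0, 1, 1]:
+        yield {"label": label}
+
+def compose(*readers):
+    # Zip several readers and merge the dictionaries they yield, so one
+    # composite reader can feed several data layers at once.
+    for dicts in zip(*[r() for r in readers]):
+        merged = {}
+        for d in dicts:
+            merged.update(d)
+        yield merged
+
+# Each yielded dictionary maps data-layer names to one data instance:
+# {"image": "img0", "label": 0}, {"image": "img1", "label": 1}, ...
+```
+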
+
+## Training
+
+The recommended way to train a model is to call `paddle.train`,
+which simply calls `paddle.trainer.Default`, a global variable of
+type `paddle.trainer.SGD`. Equivalently, we can do
+
+```python
+opt = paddle.trainer.SGD(..., paddle.updater.Adam(...))
+opt.train(topology, parameters, reader=read, ...)
+```
+
+### Updater
+
+Please be aware that a trainer can accept an updater as its data
+member, where an updater is a class derived from
+`paddle.trainer.Updater`. This is to make it easier to customize
+trainers, as discussed
+[here](https://github.com/PaddlePaddle/Paddle/issues/1319); a toy
+custom updater is sketched at the end of this document.
+
+### Event Handler
+
+`paddle.train` and `paddle.trainer.XXX.train` take an optional
+parameter `event_handler`, which should be either `None` or a function
+that handles some events:
+
+1. BeginTraining
+1. EndTraining
+1. BeginIteration
+1. EndIteration
+1. BeginPass
+1. EndPass
+
+where EndPass is sent if and only if the reader yields
+`end_pass=True`.
+
+An example is as follows:
+
+```python
+def event_handler(event):
+    if isinstance(event, paddle.event.EndIteration):
+        print paddle.test(...)
+
+paddle.train(topology, parameters, reader, event_handler)
+```
+
+If we are writing a PaddlePaddle program in and for IPython/Jupyter,
+we can use matplotlib in the event handler to plot a curve of
+cost/error versus iterations, as shown
+[here](https://blog.dominodatalab.com/interactive-dashboards-in-jupyter/).
+
+### Distributed Training
+
+If a user wants to do distributed training on a cluster, s/he should
+call `paddle.dist_train` and provide access tokens to the cluster as
+a parameter.
+
+For example, if the user has a TLS certificate that allows him/her to
+access a Kubernetes cluster, s/he should be able to call
+
+```python
+paddle.dist_train(model,
+                  trainer=paddle.trainer.SGD(...,
+                                             paddle.updater.Adam(...)),
+                  reader=read,
+                  k8s_user="yi",
+                  k8s_token="kube_cluster_tls.pem",
+                  k8s_job="hello",
+                  num_parameter_servers=15)
+```
+
+The pseudo code of `paddle.dist_train` is as follows:
+
+```python
+def dist_train(topology, parameters, trainer, reader, ...):
+    if os.getenv("KUBERNETES_SERVICE_HOST") is None:
+        image_name = k8s_user + '/' + k8s_job
+        docker_build(image_name)
+        docker_push()
+        kube_ctrl_start_job(image_name, k8s_user, k8s_token)
+    else:
+        rank = kube_list_containers_in_job_and_return_current_containers_rank()
+        if rank == 0:
+            master()
+        elif rank < 15:
+            parameter_server()
+        else:
+            trainer.train(model, reader=read)
+```
+
+Please be aware that if a process is running on the Kubernetes
+cluster, it will have some environment variables pre-defined.
+
+If `dist_train` doesn't see these environment variables, it knows
+that it's running on a user's personal computer, and it should work as a
+*launcher*. Otherwise, it knows that it's running on the cluster and
+needs to figure out its role as either the master, a trainer, or a
+parameter server.
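+
+As a footnote to the Updater section above, here is a minimal sketch
+of a custom updater. It assumes `paddle.trainer.Updater` exposes a
+per-parameter `update` hook and a per-iteration `end_iteration` hook;
+both hook names and signatures are hypothetical:
+
+```python
+class HalvingSGD(paddle.trainer.Updater):
+    # A toy updater: plain SGD whose learning rate halves every
+    # `halve_every` iterations.
+    def __init__(self, learning_rate, halve_every=10000):
+        self.learning_rate = learning_rate
+        self.halve_every = halve_every
+        self.iterations = 0
+
+    def update(self, parameter, gradient):
+        # Hypothetical hook: called once per parameter per iteration.
+        parameter -= self.learning_rate * gradient
+
+    def end_iteration(self):
+        # Hypothetical hook: per-iteration bookkeeping.
+        self.iterations += 1
+        if self.iterations % self.halve_every == 0:
+            self.learning_rate *= 0.5
+
+# It would plug into a trainer just like paddle.updater.Adam above:
+#   opt = paddle.trainer.SGD(..., HalvingSGD(0.01))
+#   opt.train(topology, parameters, reader=read)
+```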
diff --git a/paddle/cuda/include/hl_matrix.h b/paddle/cuda/include/hl_matrix.h index 40828dd5cc76f4197e6cfbb1121f2eef2c1ac580..6f21b82afdc6cdde785fdd8f13eef17a0fdd6324 100644 --- a/paddle/cuda/include/hl_matrix.h +++ b/paddle/cuda/include/hl_matrix.h @@ -188,48 +188,6 @@ extern void hl_param_relu_backward_diff(real* grad_o, int width, int height, int partial_sum); -/** - * @brief cos sim forward - * - * @param[out] output output data - * @param[in] input1 input1 data(matrix) - * @param[in] input2 input2 data(matrix or vector) - * @param[in] width matrix width - * @param[in] input1_height input1_height - * @param[in] input2_height input2_height - * @param[in] scale scale factor - */ -extern void hl_cossim(real* output, - real* input1, - real* input2, - int width, - int input1_height, - int input2_height, - real scale); -/** - * @brief cos sim derivate - * - * @param[in] grad output grad - * @param[in] output output data - * @param[in] prevOutX input1 data - * @param[in] prevOutY input2 data - * @param[out] prevGradX input1 grad - * @param[out] prevGradY input2 grad - * @param[in] width matrix width - * @param[in] input1_height input1 height - * @param[in] input2_height input2 height - * @param[in] scale scale factor - */ -extern void hl_cossim_derivative(real* grad, - real* output, - real* prevOutX, - real* prevOutY, - real* prevGradX, - real* prevGradY, - int width, - int input1_height, - int input2_height, - real scale); /** * @brief Matrix addition: A_d[i][j] += scale * B_d[j/channel]. diff --git a/paddle/cuda/include/stub/hl_matrix_stub.h b/paddle/cuda/include/stub/hl_matrix_stub.h index a1712d1e4d2a5dc80526b7d7b5ad7bd4f5d8c1ed..f4e6461cdcf198637b2c96fee88d1de2766aaf18 100644 --- a/paddle/cuda/include/stub/hl_matrix_stub.h +++ b/paddle/cuda/include/stub/hl_matrix_stub.h @@ -74,25 +74,6 @@ inline void hl_param_relu_backward_diff(real* grad_o, int height, int partial_sum) {} -inline void hl_cossim(real* output, - real* input1, - real* input2, - int width, - int input1_height, - int input2_height, - real scale) {} - -inline void hl_cossim_derivative(real* grad, - real* output, - real* prevOutX, - real* prevOutY, - real* prevGradX, - real* prevGradY, - int width, - int input1_height, - int input2_height, - real scale) {} - inline void hl_matrix_add_shared_bias(real* A_d, real* B_d, const int channel, diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu index cd23bd31057c5c8cd10173bc5fa5fa67f2d0e422..96c07d9c3b7a37daa9198fd7ea66b7d811600348 100644 --- a/paddle/cuda/src/hl_cuda_matrix.cu +++ b/paddle/cuda/src/hl_cuda_matrix.cu @@ -584,177 +584,6 @@ void hl_param_relu_backward_diff(real* grad_o, CHECK_SYNC("hl_param_relu_backward_diff failed"); } -template -__global__ void KeCosSim(real* output, - real* input1, - real* input2, - int width, - int input1_height, - int input2_height, - real scale) { - const int ty = blockIdx.y; - int tid = threadIdx.x; - - __shared__ real xx[blockSize]; - __shared__ real yy[blockSize]; - __shared__ real xy[blockSize]; - - xx[tid] = 0.0; - yy[tid] = 0.0; - xy[tid] = 0.0; - __syncthreads(); - - input1 += ty * width; - if (input2_height > 1) { - input2 += ty * width; - } - for (int index = tid; index < width; index += blockSize) { - real x = input1[index]; - real y = input2[index]; - xx[tid] += x * x; - yy[tid] += y * y; - xy[tid] += x * y; - } - __syncthreads(); - - for (int s = blockSize / 2; s > 0; s >>= 1) { - if (tid < s) { - xx[tid] += xx[tid + s]; - yy[tid] += yy[tid + s]; - xy[tid] += xy[tid + s]; - } - __syncthreads(); - } - if 
(tid == 0) { - output[ty] = scale * xy[0] / (sqrt(xx[0]) * sqrt(yy[0])); - } -} - -void hl_cossim(real* output, - real* input1, - real* input2, - int width, - int input1_height, - int input2_height, - real scale) { - CHECK_NOTNULL(output); - CHECK_NOTNULL(input1); - CHECK_NOTNULL(input2); - const int blockSize = 256; - dim3 threads(blockSize, 1); - dim3 grid(1, input1_height); - - KeCosSim<<>> - (output, input1, input2, width, input1_height, input2_height, scale); - CHECK_SYNC("hl_cossim failed"); -} - -template -__global__ void KeCosSimDerivative(real* grad, - real* output, - real* prevOutX, - real* prevOutY, - real* prevGradX, - real* prevGradY, - int width, - int input1_height, - int input2_height, - real scale) { - const int ty = blockIdx.y; - int tid = threadIdx.x; - - __shared__ real xx[blockSize]; - __shared__ real yy[blockSize]; - __shared__ real xy[blockSize]; - - xx[tid] = 0.0; - yy[tid] = 0.0; - xy[tid] = 0.0; - __syncthreads(); - - prevOutX += ty * width; - prevGradX += ty * width; - if (input2_height > 1) { - prevOutY += ty * width; - prevGradY += ty * width; - } - for (int index = tid; index < width; index += blockSize) { - real x = prevOutX[index]; - real y = prevOutY[index]; - xx[tid] += x * x; - yy[tid] += y * y; - xy[tid] += x * y; - } - __syncthreads(); - - for (int s = blockSize / 2; s > 0; s >>= 1) { - if (tid < s) { - xx[tid] += xx[tid + s]; - yy[tid] += yy[tid + s]; - xy[tid] += xy[tid + s]; - } - __syncthreads(); - } - if (xy[0] == 0) { - real reciprocal = 1.0 / (sqrt(xx[0]) * sqrt(yy[0])); - for (int index = tid; index < width; index += blockSize) { - prevGradX[index] += - scale * grad[ty] * prevOutY[index] * reciprocal; - if (input2_height > 1) { - prevGradY[index] += - scale * grad[ty] * prevOutX[index] * reciprocal; - } else { - paddle::paddleAtomicAdd(prevGradY + index, - scale * grad[ty] * prevOutX[index] * reciprocal); - } - } - } else { - real reciprocalXY = 1.0 / xy[0]; - real reciprocalSquareSumX = 1.0 / xx[0]; - real reciprocalSquareSumY = 1.0 / yy[0]; - for (int index = tid; index < width; index += blockSize) { - prevGradX[index] += output[ty] * grad[ty] * - (prevOutY[index] * reciprocalXY - - prevOutX[index] * reciprocalSquareSumX); - if (input2_height > 1) { - prevGradY[index] += output[ty] * grad[ty] * - (prevOutX[index] * reciprocalXY - - prevOutY[index] * reciprocalSquareSumY); - } else { - paddle::paddleAtomicAdd(prevGradY + index, output[ty] * grad[ty] * - (prevOutX[index] * reciprocalXY - - prevOutY[index] * reciprocalSquareSumY)); - } - } - } -} - - -void hl_cossim_derivative(real* grad, - real* output, - real* prevOutX, - real* prevOutY, - real* prevGradX, - real* prevGradY, - int width, - int input1_height, - int input2_height, - real scale) { - CHECK_NOTNULL(grad); - CHECK_NOTNULL(output); - CHECK_NOTNULL(prevOutX); - CHECK_NOTNULL(prevOutY); - CHECK_NOTNULL(prevGradX); - CHECK_NOTNULL(prevGradY); - const int blockSize = 256; - dim3 threads(blockSize, 1); - dim3 grid(1, input1_height); - KeCosSimDerivative<<>> - (grad, output, prevOutX, prevOutY, prevGradX, prevGradY, width, - input1_height, input2_height, scale); - CHECK_SYNC("hl_cossim_derivate failed"); -} - __global__ void KeMatrixAddSharedBias(real* A, real* B, const int channel, diff --git a/paddle/function/BufferArg.h b/paddle/function/BufferArg.h index 349b21e7e64064804c5d0ee26e82698925832c35..0dc7792f646457c22ee4791f18814afaa3809f7b 100644 --- a/paddle/function/BufferArg.h +++ b/paddle/function/BufferArg.h @@ -190,7 +190,7 @@ public: : BufferArg(VALUE_TYPE_INT32, shape, argType) { 
bufferType_ = TENSOR_SEQUENCE_ID; CHECK_EQ(shape_.ndims(), 1UL); - CHECK_GT(shape_[0], 1UL); + CHECK_GE(shape_[0], 1UL); numSeqs_ = shape_[0] - 1; } @@ -226,7 +226,8 @@ public: SequenceArg(ValueType valueType, const TensorShape& shape, ArgType argType = UNSPECIFIED) - : BufferArg(valueType, shape, argType), startPositions_(TensorShape()) { + : BufferArg(valueType, shape, argType), + startPositions_(TensorShape({shape[0]})) { bufferType_ = TENSOR_SEQUENCE_DATA; } diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index fae3b7b20a70b56dc44ea2df637281afe01a7e5a..1522510e8bb9816cb468fcf406e22560163950cc 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -27,6 +27,7 @@ if(WITH_TESTING) add_simple_unittest(ContextProjectionOpTest) add_simple_unittest(PadOpTest) add_simple_unittest(MulOpTest) + add_simple_unittest(CosSimOpTest) endif() endif() diff --git a/paddle/function/ContextProjectionOp.cpp b/paddle/function/ContextProjectionOp.cpp index 6cd4e4abee8fccf3a4745b0bfc6701df4ddfa5c0..b87750b74247bd0eb822340bc5a85d41b86ecec2 100644 --- a/paddle/function/ContextProjectionOp.cpp +++ b/paddle/function/ContextProjectionOp.cpp @@ -108,26 +108,23 @@ public: } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK(1 == inputs.size() || 2 == inputs.size()); - CHECK_EQ((size_t)1, outputs.size()); + CHECK(1UL == inputs.size() || 2UL == inputs.size()); + CHECK_EQ(1UL, outputs.size()); CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg()) << "SequenceArg required here"; const auto val_seqs = dynamic_cast(inputs[0]); auto out_seq = dynamic_cast(outputs[0]); CHECK(out_seq.data() && val_seqs.data() && val_seqs.getSequenceId().data()); - CHECK_EQ(out_seq.shape().ndims(), (size_t)2); - CHECK_EQ(val_seqs.shape().ndims(), (size_t)2); - CHECK_EQ(val_seqs.getSequenceId().shape().ndims(), (size_t)1); - if (2 == inputs.size()) { - CHECK_EQ(inputs[1].shape().ndims(), (size_t)2); - } + CHECK_EQ(out_seq.shape().ndims(), 2UL); + CHECK_EQ(val_seqs.shape().ndims(), 2UL); /// dim of output = dim of input * context_length CHECK_EQ(out_seq.shape()[1], val_seqs.shape()[1] * context_length_); /// input and output has the same batch_size CHECK_EQ(val_seqs.shape()[0], out_seq.shape()[0]); - /// dim of input == dim of weight - if (2 == inputs.size()) { + if (2UL == inputs.size()) { + CHECK_EQ(inputs[1].shape().ndims(), 2UL); + /// dim of input == dim of weight CHECK_EQ(val_seqs.shape()[1], inputs[1].shape()[1]); } @@ -135,10 +132,11 @@ public: auto out_mat = out_seq.matrix(); const auto in_mat = val_seqs.matrix(); const auto w_mat = - (2 == inputs.size()) + (2UL == inputs.size() && inputs[1].data()) ? 
inputs[1].matrix() : typename Tensor::Matrix(nullptr, 0, 0); const auto seq_vec = val_seqs.getSequenceId().vector(); + ContextProjectionForward(out_mat, in_mat, w_mat, @@ -235,36 +233,40 @@ public: } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK_EQ((size_t)1, inputs.size()); - CHECK_EQ((size_t)2, outputs.size()); + CHECK_EQ(1UL, inputs.size()); + CHECK(1UL == outputs.size() || 2UL == outputs.size()); CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg()) << "SequenceArg required here"; const auto in_seq = dynamic_cast(inputs[0]); auto out_seq = dynamic_cast(outputs[0]); CHECK(in_seq.data() && in_seq.getSequenceId().data()); - CHECK_EQ(in_seq.shape().ndims(), (size_t)2); - CHECK_EQ(in_seq.getSequenceId().shape().ndims(), (size_t)1); - CHECK_EQ(out_seq.shape().ndims(), (size_t)2); - CHECK_EQ(out_seq.getSequenceId().shape().ndims(), (size_t)1); - CHECK_EQ(outputs[1].shape().ndims(), (size_t)2); + CHECK_EQ(in_seq.shape().ndims(), 2UL); + CHECK_EQ(out_seq.shape().ndims(), 2UL); + CHECK_EQ(out_seq.getSequenceId().shape().ndims(), 1UL); - /// dim of input grad == dim of weight - CHECK_EQ(out_seq.shape()[1], outputs[1].shape()[1]); /// input and output grad has the same batch_size CHECK_EQ(out_seq.shape()[0], in_seq.shape()[0]); /// dim of output grad = dim of input grad * context_length CHECK_EQ(in_seq.shape()[1], out_seq.shape()[1] * context_length_); CHECK_EQ(out_seq.getArgType(), ADD_TO); - CHECK_EQ(outputs[1].getArgType(), ADD_TO); + + if (2UL == outputs.size()) { + CHECK_EQ(outputs[1].shape().ndims(), 2UL); + /// dim of input grad == dim of weight + CHECK_EQ(out_seq.shape()[1], outputs[1].shape()[1]); + CHECK_EQ(outputs[1].getArgType(), ADD_TO); + } const auto seq_vec = in_seq.getSequenceId().vector(); const auto out_grad_mat = in_seq.matrix(); auto in_grad_mat = !out_seq.data() ? typename Tensor::Matrix(nullptr, 0, 0) : out_seq.matrix(); - auto w_grad_mat = !outputs[1].data() - ? typename Tensor::Matrix(nullptr, 0, 0) - : outputs[1].matrix(); + auto w_grad_mat = + (2UL == outputs.size() && outputs[1].data()) + ? 
outputs[1].matrix() + : typename Tensor::Matrix(nullptr, 0, 0); + ContextProjectionBackward(out_grad_mat, in_grad_mat, w_grad_mat, @@ -304,17 +306,17 @@ public: } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK_EQ(1, static_cast(inputs.size())); - CHECK_EQ(1, static_cast(outputs.size())); + CHECK_EQ(1UL, inputs.size()); + CHECK_EQ(1UL, outputs.size()); CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg()) << "SequenceArg required here"; const auto in_seq = dynamic_cast(inputs[0]); const auto out_seq = dynamic_cast(outputs[0]); CHECK(in_seq.data() && out_seq.data() && in_seq.getSequenceId().data()); - CHECK_EQ(static_cast(out_seq.shape().ndims()), 2); - CHECK_EQ(static_cast(in_seq.shape().ndims()), 2); - CHECK_EQ(static_cast(in_seq.getSequenceId().shape().ndims()), 1); + CHECK_EQ(out_seq.shape().ndims(), 2UL); + CHECK_EQ(in_seq.shape().ndims(), 2UL); + CHECK_EQ(in_seq.getSequenceId().shape().ndims(), 1UL); /// output layer grad dim == input layer grad dim * context_length_ CHECK_EQ(in_seq.shape().ndims(), out_seq.shape().ndims() * context_length_); /// input and output has the same batch_size @@ -355,14 +357,14 @@ public: } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK_EQ(1, static_cast(inputs.size())); - CHECK_EQ(1, static_cast(outputs.size())); + CHECK_EQ(1UL, inputs.size()); + CHECK_EQ(1UL, outputs.size()); CHECK(inputs[0].isSequenceArg()) << "SequenceArg required here"; const auto in_seq = dynamic_cast(inputs[0]); CHECK(in_seq.data() && in_seq.getSequenceId().data() && outputs[0].data()); - CHECK_EQ(static_cast(outputs[0].shape().ndims()), 2); - CHECK_EQ(static_cast(in_seq.shape().ndims()), 2); - CHECK_EQ(static_cast(in_seq.getSequenceId().shape().ndims()), 1); + CHECK_EQ(outputs[0].shape().ndims(), 2UL); + CHECK_EQ(in_seq.shape().ndims(), 2UL); + CHECK_EQ(in_seq.getSequenceId().shape().ndims(), 1UL); CHECK_EQ(in_seq.shape()[0], outputs[0].shape()[0]); /// output layer grad dim == weight dim * context_length_ CHECK_EQ(in_seq.shape()[1], outputs[0].shape()[1] * context_length_); diff --git a/paddle/function/ContextProjectionOp.h b/paddle/function/ContextProjectionOp.h index 2bdd47e4e9b02483c2c5af82bf00c4e55d68f93e..6f7d936379a5378e6fd85dd86618d1b6094bd14f 100644 --- a/paddle/function/ContextProjectionOp.h +++ b/paddle/function/ContextProjectionOp.h @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once - #include "Function.h" namespace paddle { diff --git a/paddle/function/ContextProjectionOpTest.cpp b/paddle/function/ContextProjectionOpTest.cpp index c9db2ff8008e0bb0fa04370fb7b3ecd7641d2062..0f5d6a848d406d14984a0b6edad8192dab42e88b 100644 --- a/paddle/function/ContextProjectionOpTest.cpp +++ b/paddle/function/ContextProjectionOpTest.cpp @@ -28,55 +28,26 @@ void testMatrixProjectionForward(int context_start, std::max(0, (int)(context_start + context_length - 1)); if (pad == 0) is_padding = false; - FunctionCompare compare("ContextProjectionForward", - FuncConfig() - .set("context_length", context_length) - .set("context_start", context_start) - .set("begin_pad", std::max(0, -context_start))); - - CpuMatrix cpu_in(batch_size, input_dim); - cpu_in.randomizeUniform(); - GpuMatrix gpu_in(batch_size, input_dim); - gpu_in.copyFrom(cpu_in); - auto cpu_weight = - is_padding ? std::make_shared(pad, input_dim) : nullptr; - auto gpu_weight = - is_padding ? 
std::make_shared(pad, input_dim) : nullptr; - if (is_padding) { - cpu_weight->randomizeUniform(); - gpu_weight->copyFrom(*cpu_weight); + FunctionCompare test("ContextProjectionForward", + FuncConfig() + .set("context_length", context_length) + .set("context_start", context_start) + .set("begin_pad", std::max(0, -context_start))); + + // prepare input arguments + test.addSequence(SequenceIdArg(TensorShape{batch_size})); + test.addInputs( + SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batch_size, input_dim})); + if (is_padding) { // weight + test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{pad, input_dim})); } - IVectorPtr cpu_seq; - generateSequenceStartPositions(batch_size, cpu_seq); - IVectorPtr gpu_seq = IVector::create(cpu_seq->getSize(), true); - gpu_seq->copyFrom(*cpu_seq); - - CpuMatrix cpu_out(batch_size, input_dim * context_length); - GpuMatrix gpu_out(batch_size, input_dim * context_length); - cpu_out.randomizeUniform(); - gpu_out.copyFrom(cpu_out); - - BufferArgs cpu_inputs; - BufferArgs cpu_outputs; - cpu_inputs.addArg(cpu_in, *cpu_seq); - if (cpu_weight) { - cpu_inputs.addArg(*cpu_weight, *cpu_seq); - } - cpu_outputs.addArg(cpu_out, *cpu_seq, ADD_TO); - - compare.getCpuFunction()->calc(cpu_inputs, cpu_outputs); + test.addOutputs( + SequenceArg(VALUE_TYPE_FLOAT, + TensorShape{batch_size, input_dim * context_length}), + ADD_TO); - BufferArgs gpu_inputs; - BufferArgs gpu_outputs; - gpu_inputs.addArg(gpu_in, *gpu_seq); - if (gpu_weight) { - gpu_inputs.addArg(*gpu_weight, *gpu_seq); - } - gpu_outputs.addArg(gpu_out, *gpu_seq, ADD_TO); - - compare.getGpuFunction()->calc(gpu_inputs, gpu_outputs); - - autotest::TensorCheckEqual(cpu_out, gpu_out); + // run Function + test.run(); } void testMatrixProjectionBackward(int context_start, @@ -88,63 +59,31 @@ void testMatrixProjectionBackward(int context_start, std::max(0, (int)(context_start + context_length - 1)); if (pad == 0) is_padding = false; - FunctionCompare compare("ContextProjectionBackward", - FuncConfig() - .set("context_length", context_length) - .set("context_start", context_start) - .set("begin_pad", std::max(0, -context_start)) - .set("is_padding", is_padding) - .set("total_pad", pad)); - - CpuMatrix cpu_in_grad(batch_size, input_dim); - cpu_in_grad.randomizeUniform(); - GpuMatrix gpu_in_grad(batch_size, input_dim); - gpu_in_grad.copyFrom(cpu_in_grad); - - CpuMatrix cpu_out_grad(batch_size, input_dim * context_length); - cpu_out_grad.randomizeUniform(); - GpuMatrix gpu_out_grad(batch_size, input_dim * context_length); - gpu_out_grad.copyFrom(cpu_out_grad); - - IVectorPtr cpu_seq; - generateSequenceStartPositions(batch_size, cpu_seq); - IVectorPtr gpu_seq = IVector::create(cpu_seq->getSize(), true); - gpu_seq->copyFrom(*cpu_seq); - - auto cpu_w_grad = - is_padding ? std::make_shared(pad, input_dim) : nullptr; - auto gpu_w_grad = - is_padding ? 
std::make_shared(pad, input_dim) : nullptr; - if (is_padding) { - cpu_w_grad->randomizeUniform(); - gpu_w_grad->copyFrom(*cpu_w_grad); + FunctionCompare test("ContextProjectionBackward", + FuncConfig() + .set("context_length", context_length) + .set("context_start", context_start) + .set("begin_pad", std::max(0, -context_start)) + .set("is_padding", is_padding) + .set("total_pad", pad)); + + // prepare input arguments + test.addSequence(SequenceIdArg(TensorShape{batch_size})); + test.addInputs(SequenceArg( + VALUE_TYPE_FLOAT, TensorShape{batch_size, input_dim * context_length})); + test.addOutputs( + SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batch_size, input_dim}), + ADD_TO); + if (is_padding) { // weight + test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{pad, input_dim}), + ADD_TO); } - BufferArgs cpu_inputs; - BufferArgs cpu_outputs; - cpu_inputs.addArg(cpu_out_grad, *cpu_seq); - cpu_outputs.addArg(cpu_in_grad, *cpu_seq, ADD_TO); - cpu_outputs.addArg( - cpu_w_grad ? *cpu_w_grad : CpuMatrix(nullptr, 0, input_dim), ADD_TO); - - compare.getCpuFunction()->calc(cpu_inputs, cpu_outputs); - - BufferArgs gpu_inputs; - BufferArgs gpu_outputs; - gpu_inputs.addArg(gpu_out_grad, *gpu_seq); - gpu_outputs.addArg(gpu_in_grad, *gpu_seq, ADD_TO); - gpu_outputs.addArg( - gpu_w_grad ? *gpu_w_grad : GpuMatrix(nullptr, 0, input_dim), ADD_TO); - - compare.getGpuFunction()->calc(gpu_inputs, gpu_outputs); - - autotest::TensorCheckErr(cpu_in_grad, gpu_in_grad); - if (is_padding) { - autotest::TensorCheckErr(*cpu_w_grad, *gpu_w_grad); - } + // run Function + test.run(); } -TEST(ContextProjection, projection) { +TEST(ContextProjection, Projection) { for (auto context_start : {-5, -3, -1, 0, 3}) { for (auto context_length : {1, 2, 5, 7}) { for (auto trainable_padding : {false, true}) { diff --git a/paddle/function/CosSimOp.cpp b/paddle/function/CosSimOp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7ece7b2dfedaf460741c97b5a700eb632d85cabc --- /dev/null +++ b/paddle/function/CosSimOp.cpp @@ -0,0 +1,240 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "CosSimOp.h" +#include "paddle/math/Matrix.h" +#include "paddle/math/Vector.h" + +namespace paddle { +/** + * Cosine Similarity for CpuMatrix + * + * \param out_mat, output value, size: nSamples * 1. + * \param in1_mat, input value 1, size: nSamples * dim. + * \param in2_mat, input value 2, size: n2 * dim (n2 == 1 or n2 == nSamples). 
+ * \param scale, default 1.0
+ *
+ */
+template <>
+void CosSimForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat,
+                                    const CpuMatrix& in1_mat,
+                                    const CpuMatrix& in2_mat,
+                                    real scale) {
+  CHECK(out_mat.getData() && in1_mat.getData() && in2_mat.getData());
+  size_t num_samples = out_mat.getHeight();
+  size_t dim = in1_mat.getWidth();
+  /// column vector [nSamples, 1]
+  real* out = out_mat.getData();
+  const real* x = in1_mat.getData();
+  const real* y = in2_mat.getData();
+
+  /// in2 might only have one row or full rows
+  CHECK(in2_mat.getHeight() == 1LU || in2_mat.getHeight() == num_samples);
+  size_t inc = (in2_mat.getHeight() == 1LU) ? 0 : dim;
+  for (size_t i = 0; i < num_samples; ++i, x += dim, y += inc) {
+    real square_sum_x = 0;
+    real square_sum_y = 0;
+    real xy = 0;
+    for (size_t j = 0; j < dim; ++j) {
+      square_sum_x += x[j] * x[j];
+      square_sum_y += y[j] * y[j];
+      xy += x[j] * y[j];
+    }
+    CHECK(square_sum_x > 0 && square_sum_y > 0);
+    out[i] = scale * xy / (std::sqrt(square_sum_x) * std::sqrt(square_sum_y));
+  }
+}
+
+/**
+ * Cosine Similarity
+ * for each row i,
+ *   out[i] = scale * cos(input1[i], input2[i])
+ *          = scale * <input1[i], input2[i]> / sqrt(|input1[i]|^2 * |input2[i]|^2)
+ * when input2 only has one row, then for each row i,
+ *   out[i] = cos(input1[i], input2[0])
+ *
+ * \param inputs[0] input matrix 1, size: nSamples * dim.
+ * \param inputs[1] input matrix 2, size: n2 * dim (n2 == 1 or n2 == nSamples).
+ * \param outputs[0] output matrix, size: nSamples * 1.
+ */
+
+template <DeviceType Device>
+class CosSimForwardFunc : public FunctionBase {
+  void init(const FuncConfig& config) override {
+    scale_ = config.get<real>("scale");
+  }
+
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(inputs.size(), 2UL);
+    CHECK_EQ(outputs.size(), 1UL);
+
+    CHECK_EQ(inputs[0].shape().ndims(), 2UL);
+    CHECK_EQ(inputs[1].shape().ndims(), 2UL);
+    CHECK_EQ(outputs[0].shape().ndims(), 2UL);
+
+    CHECK_EQ(inputs[0].shape()[0], outputs[0].shape()[0]);
+    CHECK_EQ(inputs[0].shape()[1], inputs[1].shape()[1]);
+    CHECK_EQ(outputs[0].shape()[1], 1UL);
+
+    CHECK(outputs[0].data() && inputs[0].data() && inputs[1].data());
+
+    CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
+    auto out_mat = outputs[0].matrix<Device>();
+    const auto in1_mat = inputs[0].matrix<Device>();
+    const auto in2_mat = inputs[1].matrix<Device>();
+
+    CosSimForward<Device>(out_mat, in1_mat, in2_mat, scale_);
+  }
+
+private:
+  real scale_;
+};
+
+/**
+ * Cosine Similarity Derivative for CpuMatrix
+ *
+ * \param in1_grad  forward input grad 1, size: nSamples * dim.
+ * \param in2_grad  forward input grad 2,
+ *                  size: n2 * dim (n2 == 1 or n2 == nSamples).
+ *
+ * \param out_grad  backward loss output grad, size: nSamples * 1.
+ * \param out_val   forward output value, size: nSamples * 1.
+ * \param in1_val   forward input value 1, size: nSamples * dim.
+ * \param in2_val   forward input value 2,
+ *                  size: n2 * dim (n2 == 1 or n2 == nSamples).
+ * \param scale, default 1.0 + */ +template <> +void CosSimBackward(const CpuMatrix& out_grad, + const CpuMatrix& out_val, + const CpuMatrix& in1_val, + const CpuMatrix& in2_val, + CpuMatrix& in1_grad, + CpuMatrix& in2_grad, + real scale) { + CHECK(out_grad.getData() && out_val.getData() && in1_val.getData() && + in2_val.getData() && in1_grad.getData() && in2_grad.getData()); + CHECK_EQ(out_val.useGpu_, false) << "Matrix type are GPU, CPU required"; + + const real* grad = out_grad.getData(); + const real* out = out_val.getData(); + const real* prev_out_x = in1_val.getData(); + const real* prev_out_y = in2_val.getData(); + real* prev_grad_x = in1_grad.getData(); + real* prev_grad_y = in2_grad.getData(); + + size_t num_samples = out_grad.getHeight(); + size_t dim = in1_val.getWidth(); + CHECK_EQ(in2_val.getHeight(), in2_grad.getHeight()); + CHECK(in2_val.getHeight() == 1LU || in2_val.getHeight() == num_samples); + size_t inc = (in2_val.getHeight() == 1LU) ? 0 : dim; + for (size_t i = 0; i < num_samples; ++i, + prev_out_x += dim, + prev_out_y += inc, + prev_grad_x += dim, + prev_grad_y += inc) { + real square_sum_x = 0; + real square_sum_y = 0; + real xy = 0; + for (size_t j = 0; j < dim; ++j) { + square_sum_x += prev_out_x[j] * prev_out_x[j]; + square_sum_y += prev_out_y[j] * prev_out_y[j]; + xy += prev_out_x[j] * prev_out_y[j]; + } + CHECK(square_sum_x > 0 && square_sum_y > 0); + if (xy == 0) { + real reciprocal = + 1.0f / (std::sqrt(square_sum_x) * std::sqrt(square_sum_y)); + for (size_t j = 0; j < dim; ++j) { + prev_grad_x[j] += scale * grad[i] * prev_out_y[j] * reciprocal; + prev_grad_y[j] += scale * grad[i] * prev_out_x[j] * reciprocal; + } + } else { + real reciprocal_xy = 1.0f / xy; + real reciprocal_square_sum_x = 1.0f / square_sum_x; + real reciprocal_square_sum_y = 1.0f / square_sum_y; + for (size_t j = 0; j < dim; ++j) { + prev_grad_x[j] += + out[i] * grad[i] * (prev_out_y[j] * reciprocal_xy - + prev_out_x[j] * reciprocal_square_sum_x); + prev_grad_y[j] += + out[i] * grad[i] * (prev_out_x[j] * reciprocal_xy - + prev_out_y[j] * reciprocal_square_sum_y); + } + } + } +} + +/** + * Cosine Similarity backward Derivative + * + * \param outputs[0] forward input grad 1, size: nSamples * dim. + * \param outputs[1] forward input grad 2, + * size: n2 * dim (n2 == 1 or n2 == nSamples). + * + * \param inputs[0] backward loss output grad, size : nSamples * 1. + * \param inputs[1] forward output value, size: nSamples * 1. + * \param inputs[2] forward input value 1, size: nSamples * dim. + * \param inputs[3] forward input value 2, + * size: n2 * dim (n2 == 1 or n2 == nSamples). 
+ */
+template <DeviceType Device>
+class CosSimBackwardFunc : public FunctionBase {
+  void init(const FuncConfig& config) override {
+    scale_ = config.get<real>("scale");
+  }
+
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(inputs.size(), 4UL);
+    CHECK_EQ(outputs.size(), 2UL);
+    /// dim of out_grad and out_val == 1, column vector
+    CHECK_EQ(inputs[0].shape()[1], 1UL);
+    CHECK_EQ(inputs[1].shape()[1], 1UL);
+    /// nSamples of out_grad == out_val == in1_val == in1_grad
+    CHECK_EQ(inputs[1].shape()[0], inputs[0].shape()[0]);
+    CHECK_EQ(inputs[2].shape()[0], inputs[0].shape()[0]);
+    CHECK_EQ(outputs[0].shape()[0], inputs[0].shape()[0]);
+    /// dim of in1_val == in2_val == in1_grad == in2_grad
+    CHECK_EQ(inputs[3].shape()[1], inputs[2].shape()[1]);
+    CHECK_EQ(outputs[0].shape()[1], inputs[2].shape()[1]);
+    CHECK_EQ(outputs[1].shape()[1], inputs[2].shape()[1]);
+
+    CHECK(inputs[0].data() && inputs[1].data() && inputs[2].data() &&
+          inputs[3].data() && outputs[0].data() && outputs[1].data());
+
+    CHECK_EQ(outputs[0].getArgType(), ADD_TO);
+    CHECK_EQ(outputs[1].getArgType(), ADD_TO);
+
+    const auto out_grad = inputs[0].matrix<Device>();
+    const auto out_val = inputs[1].matrix<Device>();
+    const auto in1_val = inputs[2].matrix<Device>();
+    const auto in2_val = inputs[3].matrix<Device>();
+    auto in1_grad = outputs[0].matrix<Device>();
+    auto in2_grad = outputs[1].matrix<Device>();
+
+    CosSimBackward<Device>(
+        out_grad, out_val, in1_val, in2_val, in1_grad, in2_grad, scale_);
+  }
+
+private:
+  real scale_;
+};
+
+REGISTER_TYPED_FUNC(CosSimForward, CPU, CosSimForwardFunc);
+REGISTER_TYPED_FUNC(CosSimBackward, CPU, CosSimBackwardFunc);
+#ifndef PADDLE_ONLY_CPU
+REGISTER_TYPED_FUNC(CosSimForward, GPU, CosSimForwardFunc);
+REGISTER_TYPED_FUNC(CosSimBackward, GPU, CosSimBackwardFunc);
+#endif
+}  // namespace paddle
diff --git a/paddle/function/CosSimOp.h b/paddle/function/CosSimOp.h
new file mode 100644
index 0000000000000000000000000000000000000000..be73064e6375bf1e6c6a7ca6de52e9b9b755880b
--- /dev/null
+++ b/paddle/function/CosSimOp.h
@@ -0,0 +1,61 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "Function.h"
+
+namespace paddle {
+
+/**
+ * \brief Cosine Similarity Forward.
+ *        for each row i,
+ *          out[i] = scale * cos(in1[i], in2[i])
+ *                 = scale * \sum_j (in1[i][j] * in2[i][j]) /
+ *                           sqrt(\sum_j (in1[i][j]^2) * \sum_j (in2[i][j]^2))
+ *
+ * \param[out] output  output value.
+ * \param[in]  input1  input value.
+ * \param[in]  input2  input value.
+ * \param[in]  scale   default 1.0.
+ *
+ */
+template <DeviceType Device>
+void CosSimForward(typename Tensor<real, Device>::Matrix& output,
+                   const typename Tensor<real, Device>::Matrix& input1,
+                   const typename Tensor<real, Device>::Matrix& input2,
+                   real scale);
+
+/**
+ * \brief Cosine Similarity Backward for Derivative.
+ *
+ * \param[in]     output grad  backward loss output grad.
+ * \param[in]     output val   forward output value.
+ * \param[in]     input val1   forward input value 1.
+ * \param[in]     input val2   forward input value 2.
+ * \param[in/out] input grad   forward input grad 1.
+ * \param[in/out] input grad forward input grad 2. + * \param[in] scale default 1.0. + * + */ +template +void CosSimBackward(const typename Tensor::Matrix& out_grad, + const typename Tensor::Matrix& out_value, + const typename Tensor::Matrix& in1_value, + const typename Tensor::Matrix& in2_value, + typename Tensor::Matrix& in1_grad, + typename Tensor::Matrix& in2_grad, + real scale); + +} // namespace paddle diff --git a/paddle/function/CosSimOpGpu.cu b/paddle/function/CosSimOpGpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..1dd733674fa0542c76070955ec63e008b083c7d2 --- /dev/null +++ b/paddle/function/CosSimOpGpu.cu @@ -0,0 +1,241 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "hl_base.h" +#include "hl_device_functions.cuh" +#include "CosSimOp.h" + +namespace paddle { + +template +__global__ void KeCosSim(real* output, + const real* input1, + const real* input2, + int width, + int input1_height, + int input2_height, + real scale) { + const int ty = blockIdx.y; + int tid = threadIdx.x; + + __shared__ real xx[block_size]; + __shared__ real yy[block_size]; + __shared__ real xy[block_size]; + + xx[tid] = 0.0; + yy[tid] = 0.0; + xy[tid] = 0.0; + __syncthreads(); + + input1 += ty * width; + if (input2_height > 1) { + input2 += ty * width; + } + for (int index = tid; index < width; index += block_size) { + real x = input1[index]; + real y = input2[index]; + xx[tid] += x * x; + yy[tid] += y * y; + xy[tid] += x * y; + } + __syncthreads(); + + for (int s = block_size / 2; s > 0; s >>= 1) { + if (tid < s) { + xx[tid] += xx[tid + s]; + yy[tid] += yy[tid + s]; + xy[tid] += xy[tid + s]; + } + __syncthreads(); + } + if (tid == 0) { + output[ty] = scale * xy[0] / (sqrt(xx[0]) * sqrt(yy[0])); + } +} + +void hlCossim(real* output, + const real* input1, + const real* input2, + size_t width, + size_t input1_height, + size_t input2_height, + real scale) { + CHECK_NOTNULL(output); + CHECK_NOTNULL(input1); + CHECK_NOTNULL(input2); + const int block_size = 256; + dim3 threads(block_size, 1); + dim3 grid(1, input1_height); + + KeCosSim<<>> + (output, input1, input2, width, input1_height, input2_height, scale); + CHECK_SYNC("hlCossim failed"); +} + +template <> +void CosSimForward(GpuMatrix& out_mat, + const GpuMatrix& in1_mat, + const GpuMatrix& in2_mat, + real scale) { + CHECK(out_mat.getData() && in1_mat.getData() && in2_mat.getData()); + CHECK(in1_mat.useGpu_ == true && in2_mat.useGpu_ == true) + << "Matrix type are not GPU"; + + size_t num_samples = out_mat.getHeight(); + size_t dim = in1_mat.getWidth(); + real* out = out_mat.getData(); + const real* x = in1_mat.getData(); + const real* y = in2_mat.getData(); + hlCossim(out, x, y, dim, in1_mat.getHeight(), in2_mat.getHeight(), scale); +} + +template +__global__ void KeCosSimDerivative(const real* grad, + const real* output, + const real* prev_out_x, + const real* prev_out_y, + real* prev_grad_x, + real* prev_grad_y, + size_t width, + size_t input1_height, + size_t 
input2_height, + real scale) { + const int ty = blockIdx.y; + int tid = threadIdx.x; + + __shared__ real xx[block_size]; + __shared__ real yy[block_size]; + __shared__ real xy[block_size]; + + xx[tid] = 0.0; + yy[tid] = 0.0; + xy[tid] = 0.0; + __syncthreads(); + + prev_out_x += ty * width; + prev_grad_x += ty * width; + if (input2_height > 1) { + prev_out_y += ty * width; + prev_grad_y += ty * width; + } + for (int index = tid; index < width; index += block_size) { + real x = prev_out_x[index]; + real y = prev_out_y[index]; + xx[tid] += x * x; + yy[tid] += y * y; + xy[tid] += x * y; + } + __syncthreads(); + + for (int s = block_size / 2; s > 0; s >>= 1) { + if (tid < s) { + xx[tid] += xx[tid + s]; + yy[tid] += yy[tid + s]; + xy[tid] += xy[tid + s]; + } + __syncthreads(); + } + if (xy[0] == 0) { + real reciprocal = 1.0 / (sqrt(xx[0]) * sqrt(yy[0])); + for (int index = tid; index < width; index += block_size) { + prev_grad_x[index] += + scale * grad[ty] * prev_out_y[index] * reciprocal; + if (input2_height > 1) { + prev_grad_y[index] += + scale * grad[ty] * prev_out_x[index] * reciprocal; + } else { + paddle::paddleAtomicAdd(prev_grad_y + index, + scale * grad[ty] * prev_out_x[index] * reciprocal); + } + } + } else { + real reciprocalXY = 1.0 / xy[0]; + real reciprocalSquareSumX = 1.0 / xx[0]; + real reciprocalSquareSumY = 1.0 / yy[0]; + for (int index = tid; index < width; index += block_size) { + prev_grad_x[index] += output[ty] * grad[ty] * + (prev_out_y[index] * reciprocalXY - + prev_out_x[index] * reciprocalSquareSumX); + if (input2_height > 1) { + prev_grad_y[index] += output[ty] * grad[ty] * + (prev_out_x[index] * reciprocalXY - + prev_out_y[index] * reciprocalSquareSumY); + } else { + paddle::paddleAtomicAdd(prev_grad_y + index, output[ty] * grad[ty] * + (prev_out_x[index] * reciprocalXY - + prev_out_y[index] * reciprocalSquareSumY)); + } + } + } +} + +void hlCossimDerivative(const real* grad, + const real* output, + const real* prev_out_x, + const real* prev_out_y, + real* prev_grad_x, + real* prev_grad_y, + size_t width, + size_t input1_height, + size_t input2_height, + real scale) { + CHECK_NOTNULL(grad); + CHECK_NOTNULL(output); + CHECK_NOTNULL(prev_out_x); + CHECK_NOTNULL(prev_out_y); + CHECK_NOTNULL(prev_grad_x); + CHECK_NOTNULL(prev_grad_y); + const int block_size = 256; + dim3 threads(block_size, 1); + dim3 grid(1, input1_height); + KeCosSimDerivative<<>> + (grad, output, prev_out_x, prev_out_y, prev_grad_x, prev_grad_y, width, + input1_height, input2_height, scale); + CHECK_SYNC("hlCossimDerivate failed"); +} + +template <> +void CosSimBackward(const GpuMatrix& out_grad, + const GpuMatrix& out_val, + const GpuMatrix& in1_val, + const GpuMatrix& in2_val, + GpuMatrix& in1_grad, + GpuMatrix& in2_grad, + real scale) { + CHECK(out_grad.getData() && out_val.getData() && in1_val.getData() && + in2_val.getData() && in1_grad.getData() && in2_grad.getData()); + CHECK(out_grad.useGpu_ && out_val.useGpu_ && in1_val.useGpu_ + && in2_val.useGpu_ && in1_grad.useGpu_ && in2_grad.useGpu_) + << "Matrix types are not equally GPU"; + + size_t dim = in1_val.getWidth(); + const real* grad = out_grad.getData(); + const real* out = out_val.getData(); + const real* prev_out_x = in1_val.getData(); + const real* prev_out_y = in2_val.getData(); + real* prev_grad_x = in1_grad.getData(); + real* prev_grad_y = in2_grad.getData(); + hlCossimDerivative(grad, + out, + prev_out_x, + prev_out_y, + prev_grad_x, + prev_grad_y, + dim, + in1_val.getHeight(), + in2_val.getHeight(), + scale); +} + +} // namespace 
paddle diff --git a/paddle/function/CosSimOpTest.cpp b/paddle/function/CosSimOpTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..48c815f027161b48c17ce654ab819156fd856199 --- /dev/null +++ b/paddle/function/CosSimOpTest.cpp @@ -0,0 +1,64 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "FunctionTest.h" +#include "paddle/math/Matrix.h" + +using namespace paddle; // NOLINT + +void testCosSimForward(size_t height_x, + size_t height_y, + size_t width, + real scale) { + FunctionCompare test("CosSimForward", FuncConfig().set("scale", scale)); + // prepare input arguments + test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width})); + test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width})); + test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1}), + ASSIGN_TO); + // run Function + test.run(); +} + +void testCosSimBackward(size_t height_x, + size_t height_y, + size_t width, + real scale) { + FunctionCompare test("CosSimBackward", FuncConfig().set("scale", scale)); + // prepare input arguments + test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1})); + test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1})); + test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width})); + test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width})); + test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width}), + ADD_TO); + test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width}), + ADD_TO); + // run Function + test.run(); +} + +TEST(Matrix, cosSim) { + for (auto height_x : {10, 100, 1000}) { + for (auto height_y : {1, height_x}) { + for (auto width : {10, 100, 1000}) { + for (auto scale : {1.0, 2.0}) { + testCosSimForward(height_x, height_y, width, scale); + testCosSimBackward(height_x, height_y, width, scale); + } + } + } + } +} diff --git a/paddle/function/FunctionTest.h b/paddle/function/FunctionTest.h index 00f59f97d4c8c1076abe00866b786615a9801a5d..0cfafdb27f55a3e6617d31a968d2a05fc77f5b46 100644 --- a/paddle/function/FunctionTest.h +++ b/paddle/function/FunctionTest.h @@ -69,6 +69,54 @@ public: gpuMemory_.back()->getBuf(), input.valueType(), input.shape())); } + // assume one copy of sequence is shared by different SequenceArgs + void addSequence(const SequenceIdArg& input) { + CHECK_EQ(input.shape().ndims(), 1UL); + size_t batchSize = input.shape()[0]; + size_t numSeqs = batchSize / 10 + 1; + size_t sizeId = (numSeqs + 1) * sizeOfValuType(VALUE_TYPE_INT32); + cpuMemory_.emplace_back(std::make_shared(sizeId)); + gpuMemory_.emplace_back(std::make_shared(sizeId)); + cpuSeq_ = std::make_shared(cpuMemory_.back()->getBuf(), + TensorShape{numSeqs + 1}); + gpuSeq_ = std::make_shared(gpuMemory_.back()->getBuf(), + TensorShape{numSeqs + 1}); + /// init sequence Id + initArg(*cpuSeq_, batchSize); + + // todo(tianbing), delete it + CHECK_EQ(cpuSeq_->shape().getElements(), cpuSeq_->numSeqs() + 
1); + + CpuIVector cpuSeq(cpuSeq_->shape().getElements(), (int*)cpuSeq_->data()); + GpuIVector gpuSeq(gpuSeq_->shape().getElements(), (int*)gpuSeq_->data()); + gpuSeq.copyFrom(cpuSeq); + } + + void addInputs(const SequenceArg& input) { + CHECK_EQ(input.shape().ndims(), 2UL); + size_t batchSize = input.shape()[0]; + if (!cpuSeq_ || !gpuSeq_) { // sequence not exist + addSequence(SequenceIdArg(TensorShape{batchSize})); + } + + size_t size = + input.shape().getElements() * sizeOfValuType(input.valueType()); + cpuMemory_.emplace_back(std::make_shared(size)); + gpuMemory_.emplace_back(std::make_shared(size)); + + /// SequenceArg + cpuInputs_.emplace_back( + std::make_shared(cpuMemory_.back()->getBuf(), + input.valueType(), + input.shape(), + *cpuSeq_)); + gpuInputs_.emplace_back( + std::make_shared(gpuMemory_.back()->getBuf(), + input.valueType(), + input.shape(), + *gpuSeq_)); + } + // output need only contains shape, do not contains data. void addOutputs(const BufferArg& output, ArgType argType = ASSIGN_TO) { size_t size = @@ -116,24 +164,31 @@ public: std::make_shared(*gpuSparse_, argType)); } - void addInputs(const SequenceArg& input) { - size_t batchSize = input.shape()[0]; - size_t numSeqs = batchSize / 10 + 1; - - size_t sizeId = (numSeqs + 1) * sizeOfValuType(VALUE_TYPE_INT32); - cpuMemory_.emplace_back(std::make_shared(sizeId)); - gpuMemory_.emplace_back(std::make_shared(sizeId)); - - TensorShape seqsId({numSeqs + 1}); - // void* cpuBuffer = cpuMemory_.back()->getBuf(); - // void* gpuBuffer = gpuMemory_.back()->getBuf(); + void addOutputs(const SequenceArg& output, ArgType argType = ASSIGN_TO) { + CHECK_EQ(output.shape().ndims(), 2UL); + size_t batchSize = output.shape()[0]; + if (!cpuSeq_ || !gpuSeq_) { // sequence not exist + addSequence(SequenceIdArg(TensorShape{batchSize})); + } size_t size = - input.shape().getElements() * sizeOfValuType(input.valueType()); + output.shape().getElements() * sizeOfValuType(output.valueType()); cpuMemory_.emplace_back(std::make_shared(size)); gpuMemory_.emplace_back(std::make_shared(size)); - // TODO: need be implemented. + /// SequenceArg + cpuOutputs_.emplace_back( + std::make_shared(cpuMemory_.back()->getBuf(), + output.valueType(), + output.shape(), + *cpuSeq_, + argType)); + gpuOutputs_.emplace_back( + std::make_shared(gpuMemory_.back()->getBuf(), + output.valueType(), + output.shape(), + *gpuSeq_, + argType)); } void addInputs(const SparseMatrixArg& input) { @@ -193,14 +248,44 @@ public: std::shared_ptr getGpuFunction() const { return gpuFunc_; } protected: + // only init cpu argument, gpu argument copy from cpu argument. 
+ void initArg(BufferArg& arg) { + CpuVector vector(arg.shape().getElements(), (real*)arg.data()); + vector.uniform(0.001, 1); + } + + void initArg(SequenceArg& arg) { + /// init only matrix + CpuVector vector(arg.shape().getElements(), (real*)arg.data()); + vector.uniform(0.001, 1); + } + + void initArg(SequenceIdArg& arg, size_t batchSize) { + size_t numSeqs = arg.numSeqs(); + int* buf = reinterpret_cast(arg.data()); + int pos = 0; + size_t maxLen = 2 * batchSize / numSeqs; + for (int i = 0; i < (int)numSeqs; ++i) { + int len = 1 + uniformRandom(std::min( + maxLen, batchSize - pos - numSeqs + i)); + buf[i] = pos; + pos += len; + VLOG(1) << " len=" << len; + } + buf[numSeqs] = batchSize; + } + void initInputs() { for (size_t i = 0; i < cpuInputs_.size(); i++) { if (cpuInputs_[i]->isSparseArg()) { continue; /// sparse matrix already init } - initArg(*cpuInputs_[i]); - + if (cpuInputs_[i]->isSequenceArg()) { + initArg(dynamic_cast(*cpuInputs_[i])); + } else { + initArg(*cpuInputs_[i]); + } // TODO: Need a BufferCopy used to copy from one BufferArg to another. CpuVector cpuVector(cpuInputs_[i]->shape().getElements(), (real*)cpuInputs_[i]->data()); @@ -217,7 +302,11 @@ protected: continue; /// sparse matrix already init } - initArg(*cpuOutputs_[i]); + if (cpuOutputs_[i]->isSequenceArg()) { + initArg(dynamic_cast(*cpuOutputs_[i])); + } else { + initArg(*cpuOutputs_[i]); + } // TODO: Need a BufferCopy used to copy from one BufferArg to another. CpuVector cpuVector(cpuOutputs_[i]->shape().getElements(), @@ -241,28 +330,6 @@ protected: } } - // only init cpu argument, gpu argument copy from cpu argument. - void initArg(BufferArg& arg) { - CpuVector vector(arg.shape().getElements(), (real*)arg.data()); - vector.uniform(0.001, 1); - } - - void initArg(SequenceIdArg& arg, size_t batchSize) { - size_t numSeqs = arg.numSeqs(); - int* buf = reinterpret_cast(arg.data()); - int pos = 0; - size_t maxLen = 2 * batchSize / numSeqs; - for (int i = 0; i < (int)numSeqs; ++i) { - int len = uniformRandom( - std::min(maxLen, batchSize - pos - numSeqs + i)) + - 1; - buf[i] = pos; - pos += len; - VLOG(1) << " len=" << len; - } - buf[numSeqs] = batchSize; - } - protected: std::shared_ptr cpuFunc_; std::shared_ptr gpuFunc_; @@ -274,6 +341,8 @@ protected: std::vector gpuOutputs_; std::shared_ptr cpuSparse_; std::shared_ptr gpuSparse_; + std::shared_ptr cpuSeq_; + std::shared_ptr gpuSeq_; }; } // namespace paddle diff --git a/paddle/function/MulOpTest.cpp b/paddle/function/MulOpTest.cpp index 158c3c90983b12c352765479006669c5c9e5a8aa..8748eb0d79fa0fcb0935eac5bb37b44274128aa0 100644 --- a/paddle/function/MulOpTest.cpp +++ b/paddle/function/MulOpTest.cpp @@ -60,7 +60,7 @@ TEST(MulOp, DDDMatrixMul) { if (transa && transb) { continue; } - VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ') + VLOG(3) << std::setiosflags(std::ios::left) << std::setfill(' ') << " transa=" << transa << " transb=" << transb << " dimM=" << std::setw(5) << dimM << " dimN=" << std::setw(5) << dimN @@ -104,7 +104,7 @@ TEST(MuLOp, DSparseDMul) { for (const auto dimK : {3, 10}) { for (const auto nnz : {3, 10}) { for (const auto FORMAT : {SPARSE_CSR}) { - VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ') + VLOG(3) << std::setiosflags(std::ios::left) << std::setfill(' ') << " dimM=" << std::setw(5) << dimM << " dimN=" << std::setw(5) << dimN << " dimK=" << std::setw(5) << dimK @@ -150,7 +150,7 @@ TEST(MulOp, DDSparseMul) { for (const auto dimK : {3, 10}) { for (const auto nnz : {3, 10}) { for (const auto FORMAT : {SPARSE_CSR, 
SPARSE_CSC}) { - VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ') + VLOG(3) << std::setiosflags(std::ios::left) << std::setfill(' ') << " dimM=" << std::setw(5) << dimM << " dimN=" << std::setw(5) << dimN << " dimK=" << std::setw(5) << dimK @@ -197,7 +197,7 @@ TEST(MulOp, SparseDDMul) { for (const auto dimK : {3, 10}) { for (const auto nnz : {3, 10}) { for (const auto FORMAT : {SPARSE_CSC, SPARSE_CSR}) { - VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ') + VLOG(3) << std::setiosflags(std::ios::left) << std::setfill(' ') << " dimM=" << std::setw(5) << dimM << " dimN=" << std::setw(5) << dimN << " dimK=" << std::setw(5) << dimK diff --git a/paddle/gserver/dataproviders/PyDataProvider2.cpp b/paddle/gserver/dataproviders/PyDataProvider2.cpp index c26e242534f2afcff396762adb085bf99303e2b5..b8079dc0796d0e300e65ac6b6b8d3bc826b1e504 100644 --- a/paddle/gserver/dataproviders/PyDataProvider2.cpp +++ b/paddle/gserver/dataproviders/PyDataProvider2.cpp @@ -647,7 +647,7 @@ public: DataBatch& gpuBatch = *batch; std::vector& gpuArguments = gpuBatch.getStreams(); gpuArguments.resize(cpuArguments.size()); - gpuBatch.setSize(size); + gpuBatch.setSize(bsize); for (size_t i = 0; i < headers_.size(); ++i) { gpuArguments[i].resizeAndCopyFrom( cpuArguments[i], useGpu_, HPPL_STREAM_1); diff --git a/paddle/gserver/layers/CosSimLayer.cpp b/paddle/gserver/layers/CosSimLayer.cpp index 254120443dc3d41bf2422be2e88cb376d70c93d4..a6c0300acf6752a3536e7939577b561fd97d1eb8 100644 --- a/paddle/gserver/layers/CosSimLayer.cpp +++ b/paddle/gserver/layers/CosSimLayer.cpp @@ -26,15 +26,23 @@ bool CosSimLayer::init(const LayerMap& layerMap, Layer::init(layerMap, parameterMap); CHECK_EQ(inputLayers_.size(), 2LU); + + createFunction(forward_, + "CosSimForward", + FuncConfig().set("scale", (real)config_.cos_scale())); + createFunction(backward_, + "CosSimBackward", + FuncConfig().set("scale", (real)config_.cos_scale())); + return true; } void CosSimLayer::forward(PassType passType) { Layer::forward(passType); - /* malloc memory for the output_ if necessary */ int batchSize = getInputValue(0)->getHeight(); int size = getSize(); + CHECK_EQ(forward_.size(), 1) << "Only one forward function needed"; { REGISTER_TIMER_INFO("CosFwResetTimer", getName().c_str()); @@ -42,26 +50,43 @@ void CosSimLayer::forward(PassType passType) { } MatrixPtr outV = getOutputValue(); - /* activation */ { REGISTER_TIMER_INFO("CosFwAtvTimer", getName().c_str()); MatrixPtr prevOut1 = getInputValue(0); MatrixPtr prevOut2 = getInputValue(1); - outV->cosSim(*prevOut1, *prevOut2, config_.cos_scale()); + + CHECK(outV && prevOut1 && prevOut2); + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*prevOut1); + inputs.addArg(*prevOut2); + outputs.addArg(*outV, ASSIGN_TO); + forward_[0]->calc(inputs, outputs); } } void CosSimLayer::backward(const UpdateCallback& callback) { /* activation */ { REGISTER_TIMER_INFO("CosBpAtvTimer", getName().c_str()); - MatrixPtr outG = this->getOutputGrad(); - - outG->cosSimDerivative(*this->getOutputValue(), - *getInputValue(0), - *getInputValue(1), - *getInputGrad(0), - *getInputGrad(1), - config_.cos_scale()); + CHECK_EQ(backward_.size(), 1) << "Only one backward function needed"; + + const auto outG = this->getOutputGrad(); + const auto outV = this->getOutputValue(); + const auto inV1 = this->getInputValue(0); + const auto inV2 = this->getInputValue(1); + auto inG1 = this->getInputGrad(0); + auto inG2 = this->getInputGrad(1); + CHECK(outG && outV && inV1 && inV2 && inG1 && inG2); + BufferArgs inputs; + 
diff --git a/paddle/gserver/dataproviders/PyDataProvider2.cpp b/paddle/gserver/dataproviders/PyDataProvider2.cpp
index c26e242534f2afcff396762adb085bf99303e2b5..b8079dc0796d0e300e65ac6b6b8d3bc826b1e504 100644
--- a/paddle/gserver/dataproviders/PyDataProvider2.cpp
+++ b/paddle/gserver/dataproviders/PyDataProvider2.cpp
@@ -647,7 +647,7 @@ public:
     DataBatch& gpuBatch = *batch;
     std::vector<Argument>& gpuArguments = gpuBatch.getStreams();
     gpuArguments.resize(cpuArguments.size());
-    gpuBatch.setSize(size);
+    gpuBatch.setSize(bsize);
     for (size_t i = 0; i < headers_.size(); ++i) {
       gpuArguments[i].resizeAndCopyFrom(
           cpuArguments[i], useGpu_, HPPL_STREAM_1);
diff --git a/paddle/gserver/layers/CosSimLayer.cpp b/paddle/gserver/layers/CosSimLayer.cpp
index 254120443dc3d41bf2422be2e88cb376d70c93d4..a6c0300acf6752a3536e7939577b561fd97d1eb8 100644
--- a/paddle/gserver/layers/CosSimLayer.cpp
+++ b/paddle/gserver/layers/CosSimLayer.cpp
@@ -26,15 +26,23 @@ bool CosSimLayer::init(const LayerMap& layerMap,
   Layer::init(layerMap, parameterMap);
 
   CHECK_EQ(inputLayers_.size(), 2LU);
+
+  createFunction(forward_,
+                 "CosSimForward",
+                 FuncConfig().set("scale", (real)config_.cos_scale()));
+  createFunction(backward_,
+                 "CosSimBackward",
+                 FuncConfig().set("scale", (real)config_.cos_scale()));
+
   return true;
 }
 
 void CosSimLayer::forward(PassType passType) {
   Layer::forward(passType);
-  /* malloc memory for the output_ if necessary */
   int batchSize = getInputValue(0)->getHeight();
   int size = getSize();
+  CHECK_EQ(forward_.size(), 1) << "Only one forward function needed";
 
   {
     REGISTER_TIMER_INFO("CosFwResetTimer", getName().c_str());
@@ -42,26 +50,43 @@ void CosSimLayer::forward(PassType passType) {
   }
 
   MatrixPtr outV = getOutputValue();
-  /* activation */
   {
     REGISTER_TIMER_INFO("CosFwAtvTimer", getName().c_str());
     MatrixPtr prevOut1 = getInputValue(0);
     MatrixPtr prevOut2 = getInputValue(1);
-    outV->cosSim(*prevOut1, *prevOut2, config_.cos_scale());
+
+    CHECK(outV && prevOut1 && prevOut2);
+    BufferArgs inputs;
+    BufferArgs outputs;
+    inputs.addArg(*prevOut1);
+    inputs.addArg(*prevOut2);
+    outputs.addArg(*outV, ASSIGN_TO);
+    forward_[0]->calc(inputs, outputs);
   }
 }
 
 void CosSimLayer::backward(const UpdateCallback& callback) {
   /* activation */ {
     REGISTER_TIMER_INFO("CosBpAtvTimer", getName().c_str());
-    MatrixPtr outG = this->getOutputGrad();
-
-    outG->cosSimDerivative(*this->getOutputValue(),
-                           *getInputValue(0),
-                           *getInputValue(1),
-                           *getInputGrad(0),
-                           *getInputGrad(1),
-                           config_.cos_scale());
+    CHECK_EQ(backward_.size(), 1) << "Only one backward function needed";
+
+    const auto outG = this->getOutputGrad();
+    const auto outV = this->getOutputValue();
+    const auto inV1 = this->getInputValue(0);
+    const auto inV2 = this->getInputValue(1);
+    auto inG1 = this->getInputGrad(0);
+    auto inG2 = this->getInputGrad(1);
+    CHECK(outG && outV && inV1 && inV2 && inG1 && inG2);
+    BufferArgs inputs;
+    BufferArgs outputs;
+    inputs.addArg(*outG);
+    inputs.addArg(*outV);
+    inputs.addArg(*inV1);
+    inputs.addArg(*inV2);
+    outputs.addArg(*inG1, ADD_TO);
+    outputs.addArg(*inG2, ADD_TO);
+
+    backward_[0]->calc(inputs, outputs);
   }
 }
diff --git a/paddle/gserver/layers/CosSimLayer.h b/paddle/gserver/layers/CosSimLayer.h
index 65549626098f084c5e1786885e06c1bdfa3ba74c..8afaee62c2dcacba006846df0111fcbe8f7575e4 100644
--- a/paddle/gserver/layers/CosSimLayer.h
+++ b/paddle/gserver/layers/CosSimLayer.h
@@ -28,7 +28,7 @@ namespace paddle {
  *
  * - Input1: A vector (batchSize * dataDim) *
  * - Input2: A vector (batchSize * dataDim) or (1 * dataDim) *
- * - Output: A vector (dataDim * 1)
+ * - Output: A vector (batchSize * 1)
  *
  * The config file api is cos_sim.
  */
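Both CosSim layers now follow the same calling convention: wrap the raw matrices in `BufferArgs`, tag each output with an argument type, and dispatch to the registered function. `ASSIGN_TO` means the function overwrites the output buffer (forward values), while `ADD_TO` means it accumulates into it (gradients, which may already hold contributions from other consumers). A self-contained sketch of that contract, with the cosine math taken from the removed `CpuMatrix::cosSim` further below (the enum and `cosSimForwardRow` here are illustrative, not Paddle's actual declarations):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

enum ArgType { ASSIGN_TO, ADD_TO };  // mirrors the tags used in the layers above

// One row of CosSimForward: out = scale * (x . y) / (|x| * |y|).
void cosSimForwardRow(const std::vector<float>& x,
                      const std::vector<float>& y,
                      float scale,
                      float* out,
                      ArgType argType) {
  assert(x.size() == y.size());
  float xy = 0, xx = 0, yy = 0;
  for (size_t j = 0; j < x.size(); ++j) {
    xy += x[j] * y[j];
    xx += x[j] * x[j];
    yy += y[j] * y[j];
  }
  const float result = scale * xy / (std::sqrt(xx) * std::sqrt(yy));
  // ASSIGN_TO overwrites; ADD_TO accumulates into whatever is already there.
  *out = (argType == ADD_TO) ? *out + result : result;
}
```

This is why the layers pass `ASSIGN_TO` for the forward output but `ADD_TO` for `inG1`/`inG2`: input gradients must accumulate across every layer that consumes the same input.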
diff --git a/paddle/gserver/layers/CosSimVecMatLayer.cpp b/paddle/gserver/layers/CosSimVecMatLayer.cpp
index 5f652319e5620227fca166a8f72e5aed416bf5dd..aabafd473aa1e06a767d48d4c49b7b8662e992e7 100644
--- a/paddle/gserver/layers/CosSimVecMatLayer.cpp
+++ b/paddle/gserver/layers/CosSimVecMatLayer.cpp
@@ -18,7 +18,6 @@ limitations under the License. */
 #include "paddle/utils/Stat.h"
 
 namespace paddle {
-
 /**
  * @brief A layer for computing cosine similarity between a vector
  * and each row of a matrix
@@ -98,11 +97,22 @@ bool CosSimVecMatLayer::init(const LayerMap& layerMap,
                               dataDim,
                               /* trans= */ false,
                               useGpu_);
+
+  CHECK(tmpRow0 && tmpRow1 && tmpRow2 && tmpRow3 && tmpMtx0 && tmpMtx1);
+
+  createFunction(forward_,
+                 "CosSimForward",
+                 FuncConfig().set("scale", (real)config_.cos_scale()));
+  createFunction(backward_,
+                 "CosSimBackward",
+                 FuncConfig().set("scale", (real)config_.cos_scale()));
+
   return true;
 }
 
 void CosSimVecMatLayer::forward(PassType passType) {
   Layer::forward(passType);
+  CHECK_EQ(forward_.size(), 1) << "Only one forward function needed";
 
   MatrixPtr inV0 = getInputValue(0);
   MatrixPtr inV1 = getInputValue(1);
@@ -118,17 +128,25 @@ void CosSimVecMatLayer::forward(PassType passType) {
   }
 
   MatrixPtr outV = getOutputValue();
-
+  CHECK(outV && inV0 && inV1);
   REGISTER_TIMER_INFO("FwCosVMTimer", getName().c_str());
   for (size_t i = 0; i < batchSize; i++) {
     tmpRow0->setData(inV0->rowBuf(i));
     tmpMtx0->setData(inV1->rowBuf(i));
     tmpRow2->setData(outV->rowBuf(i));
-    tmpRow2->cosSim(*(tmpMtx0), *(tmpRow0), config_.cos_scale());
+
+    BufferArgs inputs;
+    BufferArgs outputs;
+    inputs.addArg(*tmpMtx0);
+    inputs.addArg(*tmpRow0);
+    outputs.addArg(*tmpRow2, ASSIGN_TO);
+    forward_[0]->calc(inputs, outputs);
   }
 }
 
 void CosSimVecMatLayer::backward(const UpdateCallback& callback) {
+  CHECK_EQ(backward_.size(), 1) << "Only one backward function needed";
+
   MatrixPtr inV0 = getInputValue(0);
   MatrixPtr inV1 = getInputValue(1);
   MatrixPtr inG0 = getInputGrad(0);
@@ -137,27 +155,27 @@ void CosSimVecMatLayer::backward(const UpdateCallback& callback) {
   MatrixPtr outG = getOutputGrad();
 
   size_t batchSize = inV0->getHeight();
-
+  CHECK(inV0 && inV1 && inG0 && inG1 && outV && outG);
   REGISTER_TIMER_INFO("BwCosVMTimer", getName().c_str());
 
-  if (inG0 && inG1) {
-    for (size_t i = 0; i < batchSize; i++) {
-      tmpRow0->setData(inV0->rowBuf(i));
-      tmpRow1->setData(inG0->rowBuf(i));
-      tmpMtx0->setData(inV1->rowBuf(i));
-      tmpMtx1->setData(inG1->rowBuf(i));
-      tmpRow2->setData(outV->rowBuf(i));
-      tmpRow3->setData(outG->rowBuf(i));
-
-      tmpRow3->cosSimDerivative(*(tmpRow2),
-                                *(tmpMtx0),
-                                *(tmpRow0),
-                                *(tmpMtx1),
-                                *(tmpRow1),
-                                config_.cos_scale());
-    }
-  } else {
-    CHECK(!inG0 || !inG1) << "Not supported";
+  for (size_t i = 0; i < batchSize; i++) {
+    tmpRow0->setData(inV0->rowBuf(i));
+    tmpRow1->setData(inG0->rowBuf(i));
+    tmpMtx0->setData(inV1->rowBuf(i));
+    tmpMtx1->setData(inG1->rowBuf(i));
+    tmpRow2->setData(outV->rowBuf(i));
+    tmpRow3->setData(outG->rowBuf(i));
+
+    BufferArgs inputs;
+    BufferArgs outputs;
+    inputs.addArg(*tmpRow3);
+    inputs.addArg(*tmpRow2);
+    inputs.addArg(*tmpMtx0);
+    inputs.addArg(*tmpRow0);
+    outputs.addArg(*tmpMtx1, ADD_TO);
+    outputs.addArg(*tmpRow1, ADD_TO);
+
+    backward_[0]->calc(inputs, outputs);
   }
 }
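For reference, the math that `CosSimForward` and `CosSimBackward` implement, reconstructed from the CPU kernels removed below. Per sample i, with scale s and incoming output gradient g_i, the removed code computes:

```latex
o_i = s\,\frac{x_i^\top y_i}{\lVert x_i\rVert\,\lVert y_i\rVert}

\Delta x_i \mathrel{+}= g_i\, o_i
  \left(\frac{y_i}{x_i^\top y_i} - \frac{x_i}{\lVert x_i\rVert^2}\right),
\qquad
\Delta y_i \mathrel{+}= g_i\, o_i
  \left(\frac{x_i}{x_i^\top y_i} - \frac{y_i}{\lVert y_i\rVert^2}\right)
```

The deleted code special-cases a zero dot product, where the updates degenerate to g_i s y_i / (|x_i||y_i|) and g_i s x_i / (|x_i||y_i|) respectively, and a one-row y is broadcast against every x_i.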
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index a8b53e2105b053399e62fba5321fd22c1fe4a50d..1964b2f8bfaebc49fe3073e03c949a8a9c3e385a 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -941,59 +941,6 @@ void GpuMatrix::softreluDerivative(Matrix& output) {
 void GpuMatrix::scaledTanh(Matrix& output, real p1, real p2) {
   BaseMatrix::scaledTanh(output, p1, p2);
 }
-void GpuMatrix::cosSim(Matrix& output1, Matrix& output2, real scale) {
-  CHECK(output1.useGpu_ == true && output2.useGpu_ == true)
-      << "Matrix type are not equal";
-  size_t numSamples = getHeight();
-  size_t dim = output1.getWidth();
-  CHECK_EQ(getWidth(), 1UL);
-  CHECK_EQ(output1.getHeight(), numSamples);
-  CHECK_EQ(output1.getWidth(), output2.getWidth());
-  real* out = getData();
-  real* x = output1.getData();
-  real* y = output2.getData();
-  hl_cossim(out, x, y, dim, output1.getHeight(), output2.getHeight(), scale);
-}
-void GpuMatrix::cosSimDerivative(Matrix& output,
-                                 Matrix& prevOut1,
-                                 Matrix& prevOut2,
-                                 Matrix& prevGrad1,
-                                 Matrix& prevGrad2,
-                                 real scale) {
-  CHECK(output.useGpu_ == true && prevOut1.useGpu_ == true &&
-        prevOut2.useGpu_ == true && prevGrad1.useGpu_ == true &&
-        prevGrad2.useGpu_ == true)
-      << "Matrix type are not equal";
-  CHECK_EQ(getWidth(), 1UL);
-  CHECK_EQ(output.getWidth(), 1UL);
-
-  size_t numSamples = getHeight();
-  CHECK_EQ(output.getHeight(), numSamples);
-  CHECK_EQ(prevOut1.getHeight(), numSamples);
-  CHECK_EQ(prevGrad1.getHeight(), numSamples);
-
-  size_t dim = prevOut1.getWidth();
-  CHECK_EQ(prevOut2.getWidth(), dim);
-  CHECK_EQ(prevGrad1.getWidth(), dim);
-  CHECK_EQ(prevGrad2.getWidth(), dim);
-
-  real* grad = getData();
-  real* out = output.getData();
-  real* prevOutX = prevOut1.getData();
-  real* prevOutY = prevOut2.getData();
-  real* prevGradX = prevGrad1.getData();
-  real* prevGradY = prevGrad2.getData();
-  hl_cossim_derivative(grad,
-                       out,
-                       prevOutX,
-                       prevOutY,
-                       prevGradX,
-                       prevGradY,
-                       dim,
-                       prevOut1.getHeight(),
-                       prevOut2.getHeight(),
-                       scale);
-}
 
 void GpuMatrix::randomizeUniform() {
   CHECK(isContiguous());
@@ -3470,105 +3417,6 @@ void CpuMatrix::softmaxDerivative(Matrix& output, Matrix& sftmaxSum) {
   }
 }
 
-void CpuMatrix::cosSim(Matrix& output1, Matrix& output2, real scale) {
-  size_t numSamples = getHeight();
-  size_t dim = output1.getWidth();
-  CHECK_EQ(getWidth(), 1UL);
-  CHECK_EQ(output1.getHeight(), numSamples);
-  CHECK_EQ(output1.getWidth(), output2.getWidth());
-
-  real* out = getData();
-  const real* x = output1.getData();
-  const real* y = output2.getData();
-  size_t yInc = dim;
-  if (output2.getHeight() == 1LU) {
-    yInc = 0;
-  } else {
-    CHECK_EQ(output2.getHeight(), numSamples);
-  }
-  for (size_t i = 0; i < numSamples; ++i, x += dim, y += yInc) {
-    real squareSumX = 0;
-    real squareSumY = 0;
-    real xy = 0;
-    for (size_t j = 0; j < dim; ++j) {
-      squareSumX += _square(x[j]);
-      squareSumY += _square(y[j]);
-      xy += x[j] * y[j];
-    }
-    CHECK(squareSumX > 0 && squareSumY > 0);
-    out[i] = scale * xy / (std::sqrt(squareSumX) * std::sqrt(squareSumY));
-  }
-}
-
-void CpuMatrix::cosSimDerivative(Matrix& output,
-                                 Matrix& prevOut1,
-                                 Matrix& prevOut2,
-                                 Matrix& prevGrad1,
-                                 Matrix& prevGrad2,
-                                 real scale) {
-  CHECK(output.useGpu_ == false) << "Matrix type are not equal";
-
-  CHECK_EQ(getWidth(), 1UL);
-  CHECK_EQ(output.getWidth(), 1UL);
-
-  size_t numSamples = getHeight();
-  CHECK_EQ(output.getHeight(), numSamples);
-  CHECK_EQ(prevOut1.getHeight(), numSamples);
-  CHECK_EQ(prevGrad1.getHeight(), numSamples);
-
-  size_t dim = prevOut1.getWidth();
-  CHECK_EQ(prevOut2.getWidth(), dim);
-  CHECK_EQ(prevGrad1.getWidth(), dim);
-  CHECK_EQ(prevGrad2.getWidth(), dim);
-
-  const real* grad = getData();
-  const real* out = output.getData();
-  const real* prevOutX = prevOut1.getData();
-  const real* prevOutY = prevOut2.getData();
-  real* prevGradX = prevGrad1.getData();
-  real* prevGradY = prevGrad2.getData();
-  size_t yInc = dim;
-  if (prevOut2.getHeight() == 1LU) {
-    yInc = 0;
-    CHECK_EQ(prevGrad2.getHeight(), 1LU);
-  } else {
-    CHECK_EQ(prevOut2.getHeight(), numSamples);
-    CHECK_EQ(prevGrad2.getHeight(), numSamples);
-  }
-  for (size_t i = 0; i < numSamples; ++i,
-              prevOutX += dim,
-              prevOutY += yInc,
-              prevGradX += dim,
-              prevGradY += yInc) {
-    real squareSumX = 0;
-    real squareSumY = 0;
-    real xy = 0;
-    for (size_t j = 0; j < dim; ++j) {
-      squareSumX += _square(prevOutX[j]);
-      squareSumY += _square(prevOutY[j]);
-      xy += prevOutX[j] * prevOutY[j];
-    }
-    CHECK(squareSumX > 0 && squareSumY > 0);
-    if (xy == 0) {
-      real reciprocal = 1.0f / (std::sqrt(squareSumX) * std::sqrt(squareSumY));
-      for (size_t j = 0; j < dim; ++j) {
-        prevGradX[j] += scale * grad[i] * prevOutY[j] * reciprocal;
-        prevGradY[j] += scale * grad[i] * prevOutX[j] * reciprocal;
-      }
-    } else {
-      real reciprocalXY = 1.0f / xy;
-      real reciprocalSquareSumX = 1.0f / squareSumX;
-      real reciprocalSquareSumY = 1.0f / squareSumY;
-      for (size_t j = 0; j < dim; ++j) {
-        prevGradX[j] += out[i] * grad[i] * (prevOutY[j] * reciprocalXY -
-                                            prevOutX[j] * reciprocalSquareSumX);
-        prevGradY[j] += out[i] * grad[i] * (prevOutX[j] * reciprocalXY -
-                                            prevOutY[j] * reciprocalSquareSumY);
-      }
-    }
-  }
-}
-
 void CpuMatrix::sumOfSquares(Matrix& output, Matrix& label) {
   CHECK(output.useGpu_ == false && label.useGpu_ == false)
       << "Matrix type are not equal";
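One detail of the removed CPU kernels worth keeping in mind when reading the new function implementations: the second input may have a single row, which the old code broadcast by walking `y` with a zero stride (`yInc = 0`). A minimal sketch of that technique (the function name and signature here are illustrative):

```cpp
#include <cstddef>

// Stride-0 broadcasting: when `y` has one row, advance it by 0 each sample
// so every x_i is paired with y row 0 (the yInc trick from the removed code).
void rowwiseDot(const float* x, const float* y,
                size_t numSamples, size_t dim, size_t heightY,
                float* out) {
  const size_t yInc = (heightY == 1) ? 0 : dim;
  for (size_t i = 0; i < numSamples; ++i, x += dim, y += yInc) {
    float acc = 0;
    for (size_t j = 0; j < dim; ++j) acc += x[j] * y[j];
    out[i] = acc;
  }
}
```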
diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h
index c92c0a272d5a72868bd61035d77aa4ed0fad7a7c..ea4bbb86b057b526c5ea294b2cd835aef65de58d 100644
--- a/paddle/math/Matrix.h
+++ b/paddle/math/Matrix.h
@@ -799,26 +799,6 @@ public:
     LOG(FATAL) << "Not implemented";
   }
 
-  /**
-   * cosine similarity, for each row i,
-   * this[i] = cos(output1[i], output2[i])
-   *
-   * output2 can only have one row, then for each row i,
-   * this[i] = cos(output1[i], output2[0])
-   */
-  virtual void cosSim(Matrix& output1, Matrix& output2, real scale = 1.0f) {
-    LOG(FATAL) << "Not implemented";
-  }
-
-  virtual void cosSimDerivative(Matrix& output,
-                                Matrix& prevOut1,
-                                Matrix& prevOut2,
-                                Matrix& prevGrad1,
-                                Matrix& prevGrad2,
-                                real scale = 1.0f) {
-    LOG(FATAL) << "Not implemented";
-  }
-
   /// print out the values of elements to os
   virtual void print(std::ostream& os) const {
     LOG(FATAL) << "Not implemented";
@@ -1324,14 +1304,6 @@ public:
   void softreluDerivative(Matrix& output);
   void scaledTanh(Matrix& output, real p1, real p2);
 
-  void cosSim(Matrix& output1, Matrix& output2, real scale);
-  void cosSimDerivative(Matrix& output,
-                        Matrix& prevOut1,
-                        Matrix& prevOut2,
-                        Matrix& prevGrad1,
-                        Matrix& prevGrad2,
-                        real scale);
-
   virtual void print(std::ostream& os) const;
   virtual void print(std::ostream& os, size_t height, size_t width) const;
@@ -1752,14 +1724,6 @@ public:
   void softreluDerivative(Matrix& output);
   void scaledTanh(Matrix& output, real p1, real p2);
 
-  void cosSim(Matrix& output1, Matrix& output2, real scale);
-  void cosSimDerivative(Matrix& output,
-                        Matrix& prevOut1,
-                        Matrix& prevOut2,
-                        Matrix& prevGrad1,
-                        Matrix& prevGrad2,
-                        real scale);
-
   void print(std::ostream& os) const;
   void print(std::ostream& os, size_t height, size_t width) const;
   void printOneRow(std::ostream& os, size_t idx) const;
diff --git a/paddle/math/tests/test_Matrix.cpp b/paddle/math/tests/test_Matrix.cpp
index a4084bdf7c6953651bfd9714fd8a5c930f774fe6..1c21da5b76e95603258a5006d0c57b00126e65b9 100644
--- a/paddle/math/tests/test_Matrix.cpp
+++ b/paddle/math/tests/test_Matrix.cpp
@@ -181,28 +181,6 @@ TEST(Matrix, copyByRowIndex) {
   }
 }
 
-void testCosSim(int heightX, int heightY, int width, real scale) {
-  AutoCompare test(heightX, 1);
-  CpuMatrix arg1(heightX, width);
-  CpuMatrix arg2(heightY, width);
-  arg1.randomizeUniform();
-  arg2.randomizeUniform();
-  arg2.add(-0.5);
-  test.cmpWithArg(&Matrix::cosSim, arg1, arg2, scale);
-}
-
-TEST(Matrix, cosSim) {
-  for (auto heightX : {10, 100, 1000}) {
-    for (auto heightY : {1, heightX}) {
-      for (auto width : {10, 100, 1000}) {
-        for (auto scale : {1.0, 2.0}) {
-          testCosSim(heightX, heightY, width, scale);
-        }
-      }
-    }
-  }
-}
-
 void testParamReluForward(int height, int width, int w_height, int w_width) {
   AutoCompare test(height, width);
   CpuMatrix arg1(height, width);
diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp
index e024f2cf1b913f56301ac7b3380f0c382818f413..6caaea443c1df756bfeb775154e8a90400cc3211 100644
--- a/paddle/math/tests/test_matrixCompare.cpp
+++ b/paddle/math/tests/test_matrixCompare.cpp
@@ -720,61 +720,6 @@ TEST(Matrix, sequenceAvgForward) {
   }
 }
 
-void testCosSimDerivate(int heightX, int heightY, int width, real scale) {
-  MatrixPtr prevOutX = CpuMatrix::create(heightX, width, false, false);
-  MatrixPtr prevOutY = CpuMatrix::create(heightY, width, false, false);
-  MatrixPtr grad = CpuMatrix::create(heightX, 1, false, false);
-  MatrixPtr output = CpuMatrix::create(heightX, 1, false, false);
-  MatrixPtr prevGradX = CpuMatrix::create(heightX, width, false, false);
-  MatrixPtr prevGradY = CpuMatrix::create(heightY, width, false, false);
-
-  prevOutX->randomizeUniform();
-  prevOutY->randomizeUniform();
-  grad->randomizeUniform();
-  output->randomizeUniform();
-  prevGradX->randomizeUniform();
-  prevGradY->randomizeUniform();
-
-  MatrixPtr prevOutXGpu = GpuMatrix::create(heightX, width, false, true);
-  MatrixPtr prevOutYGpu = GpuMatrix::create(heightY, width, false, true);
-  MatrixPtr gradGpu = GpuMatrix::create(heightX, 1, false, true);
-  MatrixPtr outputGpu = GpuMatrix::create(heightX, 1, false, true);
-  MatrixPtr prevGradXGpu = GpuMatrix::create(heightX, width, false, true);
-  MatrixPtr prevGradYGpu = GpuMatrix::create(heightY, width, false, true);
-
-  prevOutXGpu->copyFrom(*prevOutX);
-  prevOutYGpu->copyFrom(*prevOutY);
-  gradGpu->copyFrom(*grad);
-  outputGpu->copyFrom(*output);
-  prevGradXGpu->copyFrom(*prevGradX);
-  prevGradYGpu->copyFrom(*prevGradY);
-
-  grad->cosSimDerivative(
-      *output, *prevOutX, *prevOutY, *prevGradX, *prevGradY, scale);
-
-  gradGpu->cosSimDerivative(*outputGpu,
-                            *prevOutXGpu,
-                            *prevOutYGpu,
-                            *prevGradXGpu,
-                            *prevGradYGpu,
-                            scale);
-
-  TensorCheckErr(*prevGradX, *prevGradXGpu);
-  TensorCheckErr(*prevGradY, *prevGradYGpu);
-}
-
-TEST(Matrix, cosSimDerivate) {
-  for (auto heightX : {1, 10, 100}) {
-    for (auto heightY : {1, heightX}) {
-      for (auto width : {1, 10, 100}) {
-        for (auto scale : {1.0, 2.0}) {
-          testCosSimDerivate(heightX, heightY, width, scale);
-        }
-      }
-    }
-  }
-}
-
 void testParamReluBackwardDiff(int height,
                                int width,
                                int w_height,
diff --git a/paddle/utils/Util.cpp b/paddle/utils/Util.cpp
index 220aac1ff11e0ff263df8459f539237944b94c81..dbab4ec43ca2fa691445131d2cb14f51721a2e4c 100644
--- a/paddle/utils/Util.cpp
+++ b/paddle/utils/Util.cpp
@@ -289,6 +289,7 @@ void mkDir(const char* filename) {
 
 void mkDirRecursively(const char* dir) {
   struct stat sb;
+  if (*dir == 0) return;  // empty string
   if (!stat(dir, &sb)) return;
 
   mkDirRecursively(path::dirname(dir).c_str());
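The one-line `mkDirRecursively` guard deserves a note: applying `path::dirname` repeatedly to a relative path eventually yields the empty string, on which `stat` fails, so without the guard the function would recurse on `""` forever. A self-contained sketch of the fixed control flow (`dirnameOf` is a stand-in for `paddle::path::dirname`, not the real helper):

```cpp
#include <string>
#include <sys/stat.h>
#include <sys/types.h>

static std::string dirnameOf(const std::string& p) {
  size_t slash = p.find_last_of('/');
  return slash == std::string::npos ? "" : p.substr(0, slash);
}

static void mkDirRecursivelySketch(const char* dir) {
  struct stat sb;
  if (*dir == 0) return;        // empty string: recursion bottoms out here
  if (!stat(dir, &sb)) return;  // path already exists

  mkDirRecursivelySketch(dirnameOf(dir).c_str());  // create parents first
  mkdir(dir, 0755);             // then the leaf directory
}
```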
diff --git a/python/paddle/trainer_config_helpers/data_sources.py b/python/paddle/trainer_config_helpers/data_sources.py
index 0ea8fc77eef9f5daeaa262ce7808db3e980f991c..ab9a2562dcccb394c0b24741ceeb10061e40cb9a 100644
--- a/python/paddle/trainer_config_helpers/data_sources.py
+++ b/python/paddle/trainer_config_helpers/data_sources.py
@@ -201,7 +201,7 @@ def define_py_data_sources2(train_list, test_list, module, obj, args=None):
         data.load_data_module = load_data_module
         data.load_data_object = load_data_object
         data.load_data_args = load_data_args
-        data.async_load_data = True
+        data.async_load_data = False
        return data

    define_py_data_sources(
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 66fa58ac91e33bfeac37d1bfbdad8dab4789c4bd..1fdc4c462363712e8b5b4dee10d0aaa26f4deffa 100755
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -3677,26 +3677,27 @@ def pad_layer(input,
     For example,

-    .. code-block::
-
-      input(2,2,2,3) = [
-                         [ [[1,2,3], [3,4,5]],
-                           [[2,3,5], [1,6,7]] ],
-                         [ [[4,3,1], [1,8,7]],
-                           [[3,8,9], [2,3,5]] ]
-                       ]
-
-      pad_c=[1,1], pad_h=[0,0], pad_w=[0,0]
-      output(2,4,2,3) = [
-                          [ [[0,0,0], [0,0,0]],
-                            [[1,2,3], [3,4,5]],
-                            [[2,3,5], [1,6,7]],
-                            [[0,0,0], [0,0,0]] ],
-                          [ [[0,0,0], [0,0,0]],
-                            [[4,3,1], [1,8,7]],
-                            [[3,8,9], [2,3,5]],
-                            [[0,0,0], [0,0,0]] ]
-                        ]
+    .. code-block:: python
+
+       input(2,2,2,3) = [
+                          [ [[1,2,3], [3,4,5]],
+                            [[2,3,5], [1,6,7]] ],
+                          [ [[4,3,1], [1,8,7]],
+                            [[3,8,9], [2,3,5]] ]
+                        ]
+
+       pad_c=[1,1], pad_h=[0,0], pad_w=[0,0]
+
+       output(2,4,2,3) = [
+                           [ [[0,0,0], [0,0,0]],
+                             [[1,2,3], [3,4,5]],
+                             [[2,3,5], [1,6,7]],
+                             [[0,0,0], [0,0,0]] ],
+                           [ [[0,0,0], [0,0,0]],
+                             [[4,3,1], [1,8,7]],
+                             [[3,8,9], [2,3,5]],
+                             [[0,0,0], [0,0,0]] ]
+                         ]

     The simply usage is:

@@ -4191,13 +4192,7 @@ def block_expand_layer(input,

 @wrap_name_default()
 @layer_support()
-def maxout_layer(input,
-                 groups,
-                 num_channels=None,
-                 size_x=None,
-                 size_y=None,
-                 name=None,
-                 layer_attr=None):
+def maxout_layer(input, groups, num_channels=None, name=None, layer_attr=None):
     """
     A layer to do max out on conv layer output.
       - Input: output of a conv layer.
@@ -4227,12 +4222,6 @@ def maxout_layer(input,
     :type num_channels: int|None
     :param groups: The group number of input layer.
     :type groups: int
-    :param size_x: conv output width. If None will be set
-                   automatically from previous output.
-    :type size_x: int|None
-    :param size_y: conv output height. If None will be set
-                   automatically from previous output.
-    :type size_y: int|None
     :param name: The name of this layer, which can not specify.
     :type name: None|basestring.
     :param layer_attr: Extra Layer attribute.
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_split_datasource.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_split_datasource.protostr
index 1cfb92255aa92fa3fbc16a816851a5c2f81c2b56..569b0b945a762e8b596e197adc06df64e33311af 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_split_datasource.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_split_datasource.protostr
@@ -19,7 +19,7 @@ model_config {
 data_config {
   type: "py2"
   files: "train.list"
-  async_load_data: true
+  async_load_data: false
   for_test: false
   load_data_module: "a"
   load_data_object: "c"
@@ -58,7 +58,7 @@ opt_config {
 test_data_config {
   type: "py2"
   files: "test.list"
-  async_load_data: true
+  async_load_data: false
   for_test: true
   load_data_module: "b"
   load_data_object: "d"