diff --git a/.travis.yml b/.travis.yml
index 162bebba091d84b295f929527de9804e65df5a65..5d82d9729b75ef493a0bd03921c453f9a519c8cd 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -54,7 +54,9 @@ before_install:
fi
- if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then paddle/scripts/travis/before_install.osx.sh; fi
- if [[ "$JOB" == "PRE_COMMIT" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi
- - pip install numpy wheel protobuf sphinx recommonmark sphinx_rtd_theme virtualenv pre-commit requests==2.9.2 LinkChecker
+      # Paddle currently uses protobuf 3.1, and protobuf 3.2 breaks backward
+      # compatibility, so we pin the Python protobuf version.
+ - pip install numpy wheel 'protobuf==3.1' sphinx recommonmark sphinx_rtd_theme virtualenv pre-commit requests==2.9.2 LinkChecker
script:
- paddle/scripts/travis/main.sh
notifications:
diff --git a/doc/design/api.md b/doc/design/api.md
new file mode 100644
index 0000000000000000000000000000000000000000..8185d2af0ea264a2e7b4e28b9ed05279e4a22014
--- /dev/null
+++ b/doc/design/api.md
@@ -0,0 +1,262 @@
+# PaddlePaddle Design Doc
+
+## Ingredients
+
+Our design principle is to start from the essence: how can we allow
+users to express and solve their problems with neural networks?
+Some essential concepts that our API has to provide include:
+
+1. A *topology* is an expression of *layers*.
+
+1. A layer could be any kind of computation, including *cost*.
+
+1. Some layers have parameters, some don't. Most costs don't have
+ parameters.
+
+1. In some topologies, layers share parameters. For
+ example,
+ [the network for training a ranking model](https://github.com/PaddlePaddle/Paddle/issues/1311#issuecomment-279121850).
+
+1. At programming time, users specify topologies and possible sharing
+ of parameters. PaddlePaddle can figure out and create parameters
+ required (and possibly shared) by one or more topologies.
+
+
+## Starting from Examples
+
+As a summary of
+[our discussion](https://github.com/PaddlePaddle/Paddle/issues/1315),
+let us present two examples here:
+
+
+### Example 1. Sharing Parameters between Layers
+
+We use the
+[3-branch ranking](https://github.com/PaddlePaddle/Paddle/issues/1311#issuecomment-279121850)
+model in this example. For your convenience, we copy-and-paste the
+model's topology as follows:
+
+```
+A -> f -\
+Q -> f --> cost
+B -> f -/
+```
+
+The following program trains the topology, including the cost, and then
+uses a sub-network of the trained topology for inference:
+
+```python
+def f(x):
+  e = paddle.layer.embedding(x, parameter_name="embedding")
+ o = paddle.layer.softmax(e, parameter_name="semantic")
+ return o
+
+# Create 3 topologies (subnets), they share parameters because all
+# corresponding layers have the same parameter names.
+fA = f(paddle.layer.data(input_name="A"))
+fB = f(paddle.layer.data(input_name="B"))
+fQ = f(paddle.layer.data(input_name="Q"))
+
+topology = paddle.layer.less_than(
+ paddle.layer.cross_entropy(fA, fQ),
+    paddle.layer.cross_entropy(fB, fQ))
+
+# Derive parameters required in topology and create them in model.
+parameters = paddle.parameters.create(topology)
+
+# Estimate parameters used in topology from data.
+paddle.train(topology, parameters, reader=read_ranking_model_data)
+
+# Inference using fA (or fB or fQ, as they share their parameters).
+[testA, testB, testQ] = read_ranking_model_data()
+print "The semantic vector of testA: ", paddle.infer(fA, parameters, testA)
+```
+
+
+### Example 2. Sharing Parameters between "Models"
+
+We use [GAN](https://github.com/PaddlePaddle/book/tree/develop/gan) in
+this example. In the following example program, `d0` and `d1`
+correspond to the two networks in the following figure:
+
+
+
+```python
+def G(x):
+  # over-simplified example, as G has only one layer:
+  return paddle.layer.fc(x, parameter_name="G")
+
+def D(x):
+  # again, over-simplified:
+  return paddle.layer.fc(x, parameter_name="D")
+
+# Construct the first topology, which contains both D and G.
+# By learning this topology, we update parameters of G.
+d0 = paddle.layer.should_be_false(D(G(paddle.layer.data())))
+
+# Construct a second topology d1, which contains only D. By
+# training this topology, we update parameters of D. Note
+# that d1 shares parameters with d0.
+d1 = paddle.layer.should_be_true(D(paddle.layer.data()))
+
+# Create parameters from a list of multiple topologies (models) for
+# the chance to share parameters between these topologies.
+parameters = paddle.parameters.create([d0, d1])
+
+# Iterative training of GAN.
+for ...:
+ train(d0, parameters, reader=read_from_rng, immutable_parameters={"D"})
+ train(d1, parameters, reader=read_from_realistic_images)
+
+# Use d1 for inference:
+print "D thinks a batch of images are realistic ", infer(d1, parameters, read_mnist_images)
+```
+
+
+### Summarization
+
+
+The above two programs reveal some important design concerns:
+
+1. Users describe a topology as an expression of layers. Every layer
+   has a *parameter name*. If users don't specify it explicitly, it is
+   automatically generated as a unique name. By specifying the
+   parameter name, users can specify the sharing of parameters between
+   layers and even between topologies.
+
+1. `paddle.parameters.create` figures out parameters required by one
+ or more topologies from parameter names of layers. It creates these
+ parameters and returns a `ParameterSet` object, which is in essence
+ a map from *parameter names* to *parameters*.
+
+1. At training and inference time, `paddle.train` and `paddle.infer`
+   require both a topology and the parameter set that holds the
+   parameters of that topology. There are several reasons for this:
+
+ 1. This prevents users from forgetting to call
+ `paddle.parameters.create`.
+ 1. `paddle.train` needs to know which parameter set to update.
+   1. Users could load another (pre-trained) parameter set and use it
+      with a topology in `paddle.infer`, as sketched after this list.
+
+1. By specifying the `immutable_parameters` parameter of
+ `paddle.train`, we can forbid the update of these parameters.
+
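+To make the pre-trained-parameters point above concrete, here is a
+minimal sketch. `paddle.parameters.load` is a hypothetical helper that
+this design does not specify:
+
+```python
+# Hypothetical helper: deserialize a previously trained parameter set.
+pretrained = paddle.parameters.load("ranking_model_params.tar")
+
+# Parameters are matched to layers by parameter names, so the loaded
+# set can be used with any compatible topology, e.g. fA from Example 1.
+print "The semantic vector of testA: ", paddle.infer(fA, pretrained, testA)
+```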
+
+## Reader
+
+Not all programming frameworks allow users to define I/O functions.
+An example is Google MapReduce, which can only read from text,
+SSTable, and RecordIO files. Hadoop MapReduce allows users to define
+readers and writers by deriving from base classes `Reader` and
+`Writer`. The former approach is less flexible but also less
+error-prone. We decided to give users the flexibility to define their
+own readers.
+
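+As a concrete (though non-binding) sketch, a reader could be a plain
+Python generator that yields one sample per iteration. The dictionary
+keys and the `parse_line` helper are illustrative assumptions:
+
+```python
+def read_ranking_model_data():
+    for line in open("train.txt"):
+        # parse_line is a user-defined helper that splits a text line
+        # into the three inputs and a label.
+        a, b, q, label = parse_line(line)
+        # Whether a reader should yield a dictionary, and how its keys
+        # map to data layers, are open questions below.
+        yield {"A": a, "B": b, "Q": q, "label": label}
+```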
+
+There are some open questions here:
+
+1. **Should a reader return a Python dictionary?**
+
+1. **How to map multiple outputs from a reader to multiple data layers?**
+
+1. **How to easily compose some existing readers to read more data and
+ feed a topology with more data layers?**
+
+
+## Training
+
+The recommended way to train a model is to call `paddle.train`,
+which simply calls `paddle.trainer.Default`, a global variable of
+type `paddle.trainer.SGD`. Equivalently, we can do
+
+```python
+opt = paddle.trainer.SGD(..., paddle.updater.Adam(...))
+opt.train(topology, parameters, reader=read, ...)
+```
+
+### Updater
+
+Please be aware that a trainer can accept an updater as its data
+member, where an updater is a class derived from
+`paddle.trainer.Updater`. This is to make it easier to customize
+trainers, as discussed
+[here](https://github.com/PaddlePaddle/Paddle/issues/1319).
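+
+As an illustration, a user-defined updater might look like the
+following sketch. The `update` method and its arguments are
+assumptions here, since the `Updater` interface is not yet specified:
+
+```python
+class ScaledSGD(paddle.trainer.Updater):
+    def __init__(self, learning_rate):
+        self.learning_rate = learning_rate
+
+    # Hypothetical hook: called once per mini-batch with the parameter
+    # set and the gradients computed from that batch.
+    def update(self, parameters, gradients):
+        for name in parameters.names():
+            parameters[name] -= self.learning_rate * gradients[name]
+
+opt = paddle.trainer.SGD(..., ScaledSGD(0.01))
+```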
+
+### Event Handler
+
+`paddle.train` and `paddle.trainer.XXX.train` take an optional
+parameter `event_handler`, which should be either `None` or a function
+that handles the following events:
+
+1. BeginTraining
+1. EndTraining
+1. BeginIteration
+1. EndIteration
+1. BeginPass
+1. EndPass
+
+where EndPass is sent if and only if the reader yields
+`end_pass=True`.
+
+An example follows:
+
+```python
+def event_handler(event):
+  if isinstance(event, paddle.event.EndIteration):
+ print paddle.test(...)
+
+paddle.train(topology, parameters, reader, event_handler)
+```
+
+If we are writing a PaddlePaddle program in and for IPython/Jupyter,
+we can use matplotlib in the event handler to plot a curve of
+cost/error versus iterations, as shown
+[here](https://blog.dominodatalab.com/interactive-dashboards-in-jupyter/).
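+
+A minimal sketch of such a handler, assuming a hypothetical
+`event.cost` attribute and plotting with matplotlib:
+
+```python
+import matplotlib.pyplot as plt
+
+costs = []
+
+def plot_handler(event):
+    if isinstance(event, paddle.event.EndIteration):
+        costs.append(event.cost)  # event.cost is an assumed attribute
+        plt.plot(costs)           # redraw the cost curve in the notebook
+```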
+
+### Distributed Training
+
+If a user wants to do distributed training on a cluster, s/he should
+call `paddle.dist_train` and provide access tokens to the cluster as
+a parameter.
+
+For example, if the user has a TLS certificate that allows him/her to
+access a Kubernetes cluster, s/he should be able to call
+
+```python
+paddle.dist_train(model,
+ trainer=paddle.trainer.SGD(...,
+ paddle.updater.Adam(...)),
+ reader=read,
+ k8s_user="yi",
+ k8s_token="kube_cluster_tls.pem",
+ k8s_job="hello",
+ num_parameter_servers=15)
+```
+
+The pseudo code of `paddle.dist_train` is as follows:
+
+```python
+def dist_train(topology, parameters, trainer, reader, ...):
+  if os.getenv("KUBERNETES_SERVICE_HOST") is None:
+ image_name = k8s_user + '/' + k8s_job
+ docker_build(image_name)
+ docker_push()
+ kube_ctrl_start_job(image_name, k8s_user, k8s_token)
+ else:
+ rank = kube_list_containers_in_job_and_return_current_containers_rank()
+ if rank == 0:
+ master()
+ elif rank < 15:
+ parameter_server()
+ else:
+ trainer.train(model, reader=read)
+```
+
+Please be aware that if a process is running on the Kubernetes
+cluster, it will have some environment variables pre-defined.
+
+If `dist_train` doesn't see these environment variables, it knows
+that it's running on users' personal computer, and it should work as a
+*launcher*. Otherwise, it knows that it's running on the cluster and
+needs to figure out its role as either the master, a trainer, or a
+parameter server.
diff --git a/paddle/cuda/include/hl_matrix.h b/paddle/cuda/include/hl_matrix.h
index 40828dd5cc76f4197e6cfbb1121f2eef2c1ac580..6f21b82afdc6cdde785fdd8f13eef17a0fdd6324 100644
--- a/paddle/cuda/include/hl_matrix.h
+++ b/paddle/cuda/include/hl_matrix.h
@@ -188,48 +188,6 @@ extern void hl_param_relu_backward_diff(real* grad_o,
int width,
int height,
int partial_sum);
-/**
- * @brief cos sim forward
- *
- * @param[out] output output data
- * @param[in] input1 input1 data(matrix)
- * @param[in] input2 input2 data(matrix or vector)
- * @param[in] width matrix width
- * @param[in] input1_height input1_height
- * @param[in] input2_height input2_height
- * @param[in] scale scale factor
- */
-extern void hl_cossim(real* output,
- real* input1,
- real* input2,
- int width,
- int input1_height,
- int input2_height,
- real scale);
-/**
- * @brief cos sim derivate
- *
- * @param[in] grad output grad
- * @param[in] output output data
- * @param[in] prevOutX input1 data
- * @param[in] prevOutY input2 data
- * @param[out] prevGradX input1 grad
- * @param[out] prevGradY input2 grad
- * @param[in] width matrix width
- * @param[in] input1_height input1 height
- * @param[in] input2_height input2 height
- * @param[in] scale scale factor
- */
-extern void hl_cossim_derivative(real* grad,
- real* output,
- real* prevOutX,
- real* prevOutY,
- real* prevGradX,
- real* prevGradY,
- int width,
- int input1_height,
- int input2_height,
- real scale);
/**
* @brief Matrix addition: A_d[i][j] += scale * B_d[j/channel].
diff --git a/paddle/cuda/include/stub/hl_matrix_stub.h b/paddle/cuda/include/stub/hl_matrix_stub.h
index a1712d1e4d2a5dc80526b7d7b5ad7bd4f5d8c1ed..f4e6461cdcf198637b2c96fee88d1de2766aaf18 100644
--- a/paddle/cuda/include/stub/hl_matrix_stub.h
+++ b/paddle/cuda/include/stub/hl_matrix_stub.h
@@ -74,25 +74,6 @@ inline void hl_param_relu_backward_diff(real* grad_o,
int height,
int partial_sum) {}
-inline void hl_cossim(real* output,
- real* input1,
- real* input2,
- int width,
- int input1_height,
- int input2_height,
- real scale) {}
-
-inline void hl_cossim_derivative(real* grad,
- real* output,
- real* prevOutX,
- real* prevOutY,
- real* prevGradX,
- real* prevGradY,
- int width,
- int input1_height,
- int input2_height,
- real scale) {}
-
inline void hl_matrix_add_shared_bias(real* A_d,
real* B_d,
const int channel,
diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu
index cd23bd31057c5c8cd10173bc5fa5fa67f2d0e422..96c07d9c3b7a37daa9198fd7ea66b7d811600348 100644
--- a/paddle/cuda/src/hl_cuda_matrix.cu
+++ b/paddle/cuda/src/hl_cuda_matrix.cu
@@ -584,177 +584,6 @@ void hl_param_relu_backward_diff(real* grad_o,
CHECK_SYNC("hl_param_relu_backward_diff failed");
}
-template<int blockSize>
-__global__ void KeCosSim(real* output,
- real* input1,
- real* input2,
- int width,
- int input1_height,
- int input2_height,
- real scale) {
- const int ty = blockIdx.y;
- int tid = threadIdx.x;
-
- __shared__ real xx[blockSize];
- __shared__ real yy[blockSize];
- __shared__ real xy[blockSize];
-
- xx[tid] = 0.0;
- yy[tid] = 0.0;
- xy[tid] = 0.0;
- __syncthreads();
-
- input1 += ty * width;
- if (input2_height > 1) {
- input2 += ty * width;
- }
- for (int index = tid; index < width; index += blockSize) {
- real x = input1[index];
- real y = input2[index];
- xx[tid] += x * x;
- yy[tid] += y * y;
- xy[tid] += x * y;
- }
- __syncthreads();
-
- for (int s = blockSize / 2; s > 0; s >>= 1) {
- if (tid < s) {
- xx[tid] += xx[tid + s];
- yy[tid] += yy[tid + s];
- xy[tid] += xy[tid + s];
- }
- __syncthreads();
- }
- if (tid == 0) {
- output[ty] = scale * xy[0] / (sqrt(xx[0]) * sqrt(yy[0]));
- }
-}
-
-void hl_cossim(real* output,
- real* input1,
- real* input2,
- int width,
- int input1_height,
- int input2_height,
- real scale) {
- CHECK_NOTNULL(output);
- CHECK_NOTNULL(input1);
- CHECK_NOTNULL(input2);
- const int blockSize = 256;
- dim3 threads(blockSize, 1);
- dim3 grid(1, input1_height);
-
-  KeCosSim<blockSize><<<grid, threads, 0, STREAM_DEFAULT>>>
- (output, input1, input2, width, input1_height, input2_height, scale);
- CHECK_SYNC("hl_cossim failed");
-}
-
-template<int blockSize>
-__global__ void KeCosSimDerivative(real* grad,
- real* output,
- real* prevOutX,
- real* prevOutY,
- real* prevGradX,
- real* prevGradY,
- int width,
- int input1_height,
- int input2_height,
- real scale) {
- const int ty = blockIdx.y;
- int tid = threadIdx.x;
-
- __shared__ real xx[blockSize];
- __shared__ real yy[blockSize];
- __shared__ real xy[blockSize];
-
- xx[tid] = 0.0;
- yy[tid] = 0.0;
- xy[tid] = 0.0;
- __syncthreads();
-
- prevOutX += ty * width;
- prevGradX += ty * width;
- if (input2_height > 1) {
- prevOutY += ty * width;
- prevGradY += ty * width;
- }
- for (int index = tid; index < width; index += blockSize) {
- real x = prevOutX[index];
- real y = prevOutY[index];
- xx[tid] += x * x;
- yy[tid] += y * y;
- xy[tid] += x * y;
- }
- __syncthreads();
-
- for (int s = blockSize / 2; s > 0; s >>= 1) {
- if (tid < s) {
- xx[tid] += xx[tid + s];
- yy[tid] += yy[tid + s];
- xy[tid] += xy[tid + s];
- }
- __syncthreads();
- }
- if (xy[0] == 0) {
- real reciprocal = 1.0 / (sqrt(xx[0]) * sqrt(yy[0]));
- for (int index = tid; index < width; index += blockSize) {
- prevGradX[index] +=
- scale * grad[ty] * prevOutY[index] * reciprocal;
- if (input2_height > 1) {
- prevGradY[index] +=
- scale * grad[ty] * prevOutX[index] * reciprocal;
- } else {
- paddle::paddleAtomicAdd(prevGradY + index,
- scale * grad[ty] * prevOutX[index] * reciprocal);
- }
- }
- } else {
- real reciprocalXY = 1.0 / xy[0];
- real reciprocalSquareSumX = 1.0 / xx[0];
- real reciprocalSquareSumY = 1.0 / yy[0];
- for (int index = tid; index < width; index += blockSize) {
- prevGradX[index] += output[ty] * grad[ty] *
- (prevOutY[index] * reciprocalXY -
- prevOutX[index] * reciprocalSquareSumX);
- if (input2_height > 1) {
- prevGradY[index] += output[ty] * grad[ty] *
- (prevOutX[index] * reciprocalXY -
- prevOutY[index] * reciprocalSquareSumY);
- } else {
- paddle::paddleAtomicAdd(prevGradY + index, output[ty] * grad[ty] *
- (prevOutX[index] * reciprocalXY -
- prevOutY[index] * reciprocalSquareSumY));
- }
- }
- }
-}
-
-
-void hl_cossim_derivative(real* grad,
- real* output,
- real* prevOutX,
- real* prevOutY,
- real* prevGradX,
- real* prevGradY,
- int width,
- int input1_height,
- int input2_height,
- real scale) {
- CHECK_NOTNULL(grad);
- CHECK_NOTNULL(output);
- CHECK_NOTNULL(prevOutX);
- CHECK_NOTNULL(prevOutY);
- CHECK_NOTNULL(prevGradX);
- CHECK_NOTNULL(prevGradY);
- const int blockSize = 256;
- dim3 threads(blockSize, 1);
- dim3 grid(1, input1_height);
-  KeCosSimDerivative<blockSize><<<grid, threads, 0, STREAM_DEFAULT>>>
- (grad, output, prevOutX, prevOutY, prevGradX, prevGradY, width,
- input1_height, input2_height, scale);
- CHECK_SYNC("hl_cossim_derivate failed");
-}
-
__global__ void KeMatrixAddSharedBias(real* A,
real* B,
const int channel,
diff --git a/paddle/function/BufferArg.h b/paddle/function/BufferArg.h
index 349b21e7e64064804c5d0ee26e82698925832c35..0dc7792f646457c22ee4791f18814afaa3809f7b 100644
--- a/paddle/function/BufferArg.h
+++ b/paddle/function/BufferArg.h
@@ -190,7 +190,7 @@ public:
: BufferArg(VALUE_TYPE_INT32, shape, argType) {
bufferType_ = TENSOR_SEQUENCE_ID;
CHECK_EQ(shape_.ndims(), 1UL);
- CHECK_GT(shape_[0], 1UL);
+ CHECK_GE(shape_[0], 1UL);
numSeqs_ = shape_[0] - 1;
}
@@ -226,7 +226,8 @@ public:
SequenceArg(ValueType valueType,
const TensorShape& shape,
ArgType argType = UNSPECIFIED)
- : BufferArg(valueType, shape, argType), startPositions_(TensorShape()) {
+ : BufferArg(valueType, shape, argType),
+ startPositions_(TensorShape({shape[0]})) {
bufferType_ = TENSOR_SEQUENCE_DATA;
}
diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt
index fae3b7b20a70b56dc44ea2df637281afe01a7e5a..1522510e8bb9816cb468fcf406e22560163950cc 100644
--- a/paddle/function/CMakeLists.txt
+++ b/paddle/function/CMakeLists.txt
@@ -27,6 +27,7 @@ if(WITH_TESTING)
add_simple_unittest(ContextProjectionOpTest)
add_simple_unittest(PadOpTest)
add_simple_unittest(MulOpTest)
+ add_simple_unittest(CosSimOpTest)
endif()
endif()
diff --git a/paddle/function/ContextProjectionOp.cpp b/paddle/function/ContextProjectionOp.cpp
index 6cd4e4abee8fccf3a4745b0bfc6701df4ddfa5c0..b87750b74247bd0eb822340bc5a85d41b86ecec2 100644
--- a/paddle/function/ContextProjectionOp.cpp
+++ b/paddle/function/ContextProjectionOp.cpp
@@ -108,26 +108,23 @@ public:
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
- CHECK(1 == inputs.size() || 2 == inputs.size());
- CHECK_EQ((size_t)1, outputs.size());
+ CHECK(1UL == inputs.size() || 2UL == inputs.size());
+ CHECK_EQ(1UL, outputs.size());
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
<< "SequenceArg required here";
    const auto val_seqs = dynamic_cast<const SequenceArg&>(inputs[0]);
    auto out_seq = dynamic_cast<SequenceArg&>(outputs[0]);
CHECK(out_seq.data() && val_seqs.data() && val_seqs.getSequenceId().data());
- CHECK_EQ(out_seq.shape().ndims(), (size_t)2);
- CHECK_EQ(val_seqs.shape().ndims(), (size_t)2);
- CHECK_EQ(val_seqs.getSequenceId().shape().ndims(), (size_t)1);
- if (2 == inputs.size()) {
- CHECK_EQ(inputs[1].shape().ndims(), (size_t)2);
- }
+ CHECK_EQ(out_seq.shape().ndims(), 2UL);
+ CHECK_EQ(val_seqs.shape().ndims(), 2UL);
/// dim of output = dim of input * context_length
CHECK_EQ(out_seq.shape()[1], val_seqs.shape()[1] * context_length_);
/// input and output has the same batch_size
CHECK_EQ(val_seqs.shape()[0], out_seq.shape()[0]);
- /// dim of input == dim of weight
- if (2 == inputs.size()) {
+ if (2UL == inputs.size()) {
+ CHECK_EQ(inputs[1].shape().ndims(), 2UL);
+ /// dim of input == dim of weight
CHECK_EQ(val_seqs.shape()[1], inputs[1].shape()[1]);
}
@@ -135,10 +132,11 @@ public:
    auto out_mat = out_seq.matrix<Device>();
    const auto in_mat = val_seqs.matrix<Device>();
    const auto w_mat =
-        (2 == inputs.size())
+        (2UL == inputs.size() && inputs[1].data())
            ? inputs[1].matrix<Device>()
            : typename Tensor<real, Device>::Matrix(nullptr, 0, 0);
    const auto seq_vec = val_seqs.getSequenceId().vector<int, Device>();
+
ContextProjectionForward(out_mat,
in_mat,
w_mat,
@@ -235,36 +233,40 @@ public:
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
- CHECK_EQ((size_t)1, inputs.size());
- CHECK_EQ((size_t)2, outputs.size());
+ CHECK_EQ(1UL, inputs.size());
+ CHECK(1UL == outputs.size() || 2UL == outputs.size());
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
<< "SequenceArg required here";
    const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
    auto out_seq = dynamic_cast<SequenceArg&>(outputs[0]);
CHECK(in_seq.data() && in_seq.getSequenceId().data());
- CHECK_EQ(in_seq.shape().ndims(), (size_t)2);
- CHECK_EQ(in_seq.getSequenceId().shape().ndims(), (size_t)1);
- CHECK_EQ(out_seq.shape().ndims(), (size_t)2);
- CHECK_EQ(out_seq.getSequenceId().shape().ndims(), (size_t)1);
- CHECK_EQ(outputs[1].shape().ndims(), (size_t)2);
+ CHECK_EQ(in_seq.shape().ndims(), 2UL);
+ CHECK_EQ(out_seq.shape().ndims(), 2UL);
+ CHECK_EQ(out_seq.getSequenceId().shape().ndims(), 1UL);
- /// dim of input grad == dim of weight
- CHECK_EQ(out_seq.shape()[1], outputs[1].shape()[1]);
/// input and output grad has the same batch_size
CHECK_EQ(out_seq.shape()[0], in_seq.shape()[0]);
/// dim of output grad = dim of input grad * context_length
CHECK_EQ(in_seq.shape()[1], out_seq.shape()[1] * context_length_);
CHECK_EQ(out_seq.getArgType(), ADD_TO);
- CHECK_EQ(outputs[1].getArgType(), ADD_TO);
+
+ if (2UL == outputs.size()) {
+ CHECK_EQ(outputs[1].shape().ndims(), 2UL);
+ /// dim of input grad == dim of weight
+ CHECK_EQ(out_seq.shape()[1], outputs[1].shape()[1]);
+ CHECK_EQ(outputs[1].getArgType(), ADD_TO);
+ }
    const auto seq_vec = in_seq.getSequenceId().vector<int, Device>();
    const auto out_grad_mat = in_seq.matrix<Device>();
    auto in_grad_mat =
        !out_seq.data() ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
                        : out_seq.matrix<Device>();
-    auto w_grad_mat = !outputs[1].data()
-                          ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
-                          : outputs[1].matrix<Device>();
+    auto w_grad_mat =
+        (2UL == outputs.size() && outputs[1].data())
+            ? outputs[1].matrix<Device>()
+            : typename Tensor<real, Device>::Matrix(nullptr, 0, 0);
+
ContextProjectionBackward(out_grad_mat,
in_grad_mat,
w_grad_mat,
@@ -304,17 +306,17 @@ public:
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
-    CHECK_EQ(1, static_cast<int>(inputs.size()));
-    CHECK_EQ(1, static_cast<int>(outputs.size()));
+ CHECK_EQ(1UL, inputs.size());
+ CHECK_EQ(1UL, outputs.size());
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
<< "SequenceArg required here";
    const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
    const auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]);
CHECK(in_seq.data() && out_seq.data() && in_seq.getSequenceId().data());
-    CHECK_EQ(static_cast<int>(out_seq.shape().ndims()), 2);
-    CHECK_EQ(static_cast<int>(in_seq.shape().ndims()), 2);
-    CHECK_EQ(static_cast<int>(in_seq.getSequenceId().shape().ndims()), 1);
+ CHECK_EQ(out_seq.shape().ndims(), 2UL);
+ CHECK_EQ(in_seq.shape().ndims(), 2UL);
+ CHECK_EQ(in_seq.getSequenceId().shape().ndims(), 1UL);
/// output layer grad dim == input layer grad dim * context_length_
CHECK_EQ(in_seq.shape().ndims(), out_seq.shape().ndims() * context_length_);
/// input and output has the same batch_size
@@ -355,14 +357,14 @@ public:
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
-    CHECK_EQ(1, static_cast<int>(inputs.size()));
-    CHECK_EQ(1, static_cast<int>(outputs.size()));
+ CHECK_EQ(1UL, inputs.size());
+ CHECK_EQ(1UL, outputs.size());
CHECK(inputs[0].isSequenceArg()) << "SequenceArg required here";
    const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
CHECK(in_seq.data() && in_seq.getSequenceId().data() && outputs[0].data());
-    CHECK_EQ(static_cast<int>(outputs[0].shape().ndims()), 2);
-    CHECK_EQ(static_cast<int>(in_seq.shape().ndims()), 2);
-    CHECK_EQ(static_cast<int>(in_seq.getSequenceId().shape().ndims()), 1);
+ CHECK_EQ(outputs[0].shape().ndims(), 2UL);
+ CHECK_EQ(in_seq.shape().ndims(), 2UL);
+ CHECK_EQ(in_seq.getSequenceId().shape().ndims(), 1UL);
CHECK_EQ(in_seq.shape()[0], outputs[0].shape()[0]);
/// output layer grad dim == weight dim * context_length_
CHECK_EQ(in_seq.shape()[1], outputs[0].shape()[1] * context_length_);
diff --git a/paddle/function/ContextProjectionOp.h b/paddle/function/ContextProjectionOp.h
index 2bdd47e4e9b02483c2c5af82bf00c4e55d68f93e..6f7d936379a5378e6fd85dd86618d1b6094bd14f 100644
--- a/paddle/function/ContextProjectionOp.h
+++ b/paddle/function/ContextProjectionOp.h
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
-
#include "Function.h"
namespace paddle {
diff --git a/paddle/function/ContextProjectionOpTest.cpp b/paddle/function/ContextProjectionOpTest.cpp
index c9db2ff8008e0bb0fa04370fb7b3ecd7641d2062..0f5d6a848d406d14984a0b6edad8192dab42e88b 100644
--- a/paddle/function/ContextProjectionOpTest.cpp
+++ b/paddle/function/ContextProjectionOpTest.cpp
@@ -28,55 +28,26 @@ void testMatrixProjectionForward(int context_start,
std::max(0, (int)(context_start + context_length - 1));
if (pad == 0) is_padding = false;
- FunctionCompare compare("ContextProjectionForward",
- FuncConfig()
- .set("context_length", context_length)
- .set("context_start", context_start)
- .set("begin_pad", std::max(0, -context_start)));
-
- CpuMatrix cpu_in(batch_size, input_dim);
- cpu_in.randomizeUniform();
- GpuMatrix gpu_in(batch_size, input_dim);
- gpu_in.copyFrom(cpu_in);
- auto cpu_weight =
-      is_padding ? std::make_shared<CpuMatrix>(pad, input_dim) : nullptr;
- auto gpu_weight =
-      is_padding ? std::make_shared<GpuMatrix>(pad, input_dim) : nullptr;
- if (is_padding) {
- cpu_weight->randomizeUniform();
- gpu_weight->copyFrom(*cpu_weight);
+ FunctionCompare test("ContextProjectionForward",
+ FuncConfig()
+ .set("context_length", context_length)
+ .set("context_start", context_start)
+ .set("begin_pad", std::max(0, -context_start)));
+
+ // prepare input arguments
+ test.addSequence(SequenceIdArg(TensorShape{batch_size}));
+ test.addInputs(
+ SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batch_size, input_dim}));
+ if (is_padding) { // weight
+ test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{pad, input_dim}));
}
- IVectorPtr cpu_seq;
- generateSequenceStartPositions(batch_size, cpu_seq);
- IVectorPtr gpu_seq = IVector::create(cpu_seq->getSize(), true);
- gpu_seq->copyFrom(*cpu_seq);
-
- CpuMatrix cpu_out(batch_size, input_dim * context_length);
- GpuMatrix gpu_out(batch_size, input_dim * context_length);
- cpu_out.randomizeUniform();
- gpu_out.copyFrom(cpu_out);
-
- BufferArgs cpu_inputs;
- BufferArgs cpu_outputs;
- cpu_inputs.addArg(cpu_in, *cpu_seq);
- if (cpu_weight) {
- cpu_inputs.addArg(*cpu_weight, *cpu_seq);
- }
- cpu_outputs.addArg(cpu_out, *cpu_seq, ADD_TO);
-
- compare.getCpuFunction()->calc(cpu_inputs, cpu_outputs);
+ test.addOutputs(
+ SequenceArg(VALUE_TYPE_FLOAT,
+ TensorShape{batch_size, input_dim * context_length}),
+ ADD_TO);
- BufferArgs gpu_inputs;
- BufferArgs gpu_outputs;
- gpu_inputs.addArg(gpu_in, *gpu_seq);
- if (gpu_weight) {
- gpu_inputs.addArg(*gpu_weight, *gpu_seq);
- }
- gpu_outputs.addArg(gpu_out, *gpu_seq, ADD_TO);
-
- compare.getGpuFunction()->calc(gpu_inputs, gpu_outputs);
-
- autotest::TensorCheckEqual(cpu_out, gpu_out);
+ // run Function
+ test.run();
}
void testMatrixProjectionBackward(int context_start,
@@ -88,63 +59,31 @@ void testMatrixProjectionBackward(int context_start,
std::max(0, (int)(context_start + context_length - 1));
if (pad == 0) is_padding = false;
- FunctionCompare compare("ContextProjectionBackward",
- FuncConfig()
- .set("context_length", context_length)
- .set("context_start", context_start)
- .set("begin_pad", std::max(0, -context_start))
- .set("is_padding", is_padding)
- .set("total_pad", pad));
-
- CpuMatrix cpu_in_grad(batch_size, input_dim);
- cpu_in_grad.randomizeUniform();
- GpuMatrix gpu_in_grad(batch_size, input_dim);
- gpu_in_grad.copyFrom(cpu_in_grad);
-
- CpuMatrix cpu_out_grad(batch_size, input_dim * context_length);
- cpu_out_grad.randomizeUniform();
- GpuMatrix gpu_out_grad(batch_size, input_dim * context_length);
- gpu_out_grad.copyFrom(cpu_out_grad);
-
- IVectorPtr cpu_seq;
- generateSequenceStartPositions(batch_size, cpu_seq);
- IVectorPtr gpu_seq = IVector::create(cpu_seq->getSize(), true);
- gpu_seq->copyFrom(*cpu_seq);
-
- auto cpu_w_grad =
-      is_padding ? std::make_shared<CpuMatrix>(pad, input_dim) : nullptr;
- auto gpu_w_grad =
-      is_padding ? std::make_shared<GpuMatrix>(pad, input_dim) : nullptr;
- if (is_padding) {
- cpu_w_grad->randomizeUniform();
- gpu_w_grad->copyFrom(*cpu_w_grad);
+ FunctionCompare test("ContextProjectionBackward",
+ FuncConfig()
+ .set("context_length", context_length)
+ .set("context_start", context_start)
+ .set("begin_pad", std::max(0, -context_start))
+ .set("is_padding", is_padding)
+ .set("total_pad", pad));
+
+ // prepare input arguments
+ test.addSequence(SequenceIdArg(TensorShape{batch_size}));
+ test.addInputs(SequenceArg(
+ VALUE_TYPE_FLOAT, TensorShape{batch_size, input_dim * context_length}));
+ test.addOutputs(
+ SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batch_size, input_dim}),
+ ADD_TO);
+ if (is_padding) { // weight
+ test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{pad, input_dim}),
+ ADD_TO);
}
- BufferArgs cpu_inputs;
- BufferArgs cpu_outputs;
- cpu_inputs.addArg(cpu_out_grad, *cpu_seq);
- cpu_outputs.addArg(cpu_in_grad, *cpu_seq, ADD_TO);
- cpu_outputs.addArg(
- cpu_w_grad ? *cpu_w_grad : CpuMatrix(nullptr, 0, input_dim), ADD_TO);
-
- compare.getCpuFunction()->calc(cpu_inputs, cpu_outputs);
-
- BufferArgs gpu_inputs;
- BufferArgs gpu_outputs;
- gpu_inputs.addArg(gpu_out_grad, *gpu_seq);
- gpu_outputs.addArg(gpu_in_grad, *gpu_seq, ADD_TO);
- gpu_outputs.addArg(
- gpu_w_grad ? *gpu_w_grad : GpuMatrix(nullptr, 0, input_dim), ADD_TO);
-
- compare.getGpuFunction()->calc(gpu_inputs, gpu_outputs);
-
- autotest::TensorCheckErr(cpu_in_grad, gpu_in_grad);
- if (is_padding) {
- autotest::TensorCheckErr(*cpu_w_grad, *gpu_w_grad);
- }
+ // run Function
+ test.run();
}
-TEST(ContextProjection, projection) {
+TEST(ContextProjection, Projection) {
for (auto context_start : {-5, -3, -1, 0, 3}) {
for (auto context_length : {1, 2, 5, 7}) {
for (auto trainable_padding : {false, true}) {
diff --git a/paddle/function/CosSimOp.cpp b/paddle/function/CosSimOp.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..7ece7b2dfedaf460741c97b5a700eb632d85cabc
--- /dev/null
+++ b/paddle/function/CosSimOp.cpp
@@ -0,0 +1,240 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "CosSimOp.h"
+#include "paddle/math/Matrix.h"
+#include "paddle/math/Vector.h"
+
+namespace paddle {
+/**
+ * Cosine Similarity for CpuMatrix
+ *
+ * \param out_mat, output value, size: nSamples * 1.
+ * \param in1_mat, input value 1, size: nSamples * dim.
+ * \param in2_mat, input value 2, size: n2 * dim (n2 == 1 or n2 == nSamples).
+ * \param scale, default 1.0
+ *
+ */
+template <>
+void CosSimForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat,
+ const CpuMatrix& in1_mat,
+ const CpuMatrix& in2_mat,
+ real scale) {
+ CHECK(out_mat.getData() && in1_mat.getData() && in2_mat.getData());
+ size_t num_samples = out_mat.getHeight();
+ size_t dim = in1_mat.getWidth();
+ /// column vector [nSamples, 1]
+ real* out = out_mat.getData();
+ const real* x = in1_mat.getData();
+ const real* y = in2_mat.getData();
+
+ /// in2 might only have one row or full rows
+ CHECK(in2_mat.getHeight() == 1LU || in2_mat.getHeight() == num_samples);
+ size_t inc = (in2_mat.getHeight() == 1LU) ? 0 : dim;
+ for (size_t i = 0; i < num_samples; ++i, x += dim, y += inc) {
+ real square_sum_x = 0;
+ real square_sum_y = 0;
+ real xy = 0;
+ for (size_t j = 0; j < dim; ++j) {
+ square_sum_x += x[j] * x[j];
+ square_sum_y += y[j] * y[j];
+ xy += x[j] * y[j];
+ }
+ CHECK(square_sum_x > 0 && square_sum_y > 0);
+ out[i] = scale * xy / (std::sqrt(square_sum_x) * std::sqrt(square_sum_y));
+ }
+}
+
+/**
+ * Cosine Similarity
+ * for each row i,
+ * out[i] = scale * cos(input1[i], input2[i])
+ *        = scale * <input1[i], input2[i]> / sqrt(|input1[i]|^2 * |input2[i]|^2)
+ * when input2 only has one row, then for each row i,
+ * out[i] = cos(input1[i], input2[0])
+ *
+ * \param inputs[0] input matrix 1, size: nSamples * dim.
+ * \param inputs[1] input matrix 2, size: n2 * dim (n2 == 1 or n2 == nSamples).
+ * \param outputs[0] output matrix, size : nSamples * 1.
+ */
+
+template <DeviceType Device>
+class CosSimForwardFunc : public FunctionBase {
+ void init(const FuncConfig& config) override {
+    scale_ = config.get<real>("scale");
+ }
+
+ void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+ CHECK_EQ(inputs.size(), 2UL);
+ CHECK_EQ(outputs.size(), 1UL);
+
+ CHECK_EQ(inputs[0].shape().ndims(), 2UL);
+ CHECK_EQ(inputs[1].shape().ndims(), 2UL);
+ CHECK_EQ(outputs[0].shape().ndims(), 2UL);
+
+ CHECK_EQ(inputs[0].shape()[0], outputs[0].shape()[0]);
+ CHECK_EQ(inputs[0].shape()[1], inputs[1].shape()[1]);
+ CHECK_EQ(outputs[0].shape()[1], 1UL);
+
+ CHECK(outputs[0].data() && inputs[0].data() && inputs[1].data());
+
+ CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
+    auto out_mat = outputs[0].matrix<Device>();
+    const auto in1_mat = inputs[0].matrix<Device>();
+    const auto in2_mat = inputs[1].matrix<Device>();
+
+    CosSimForward<Device>(out_mat, in1_mat, in2_mat, scale_);
+ }
+
+private:
+ real scale_;
+};
+
+/**
+ * Cosine Similarity Derivative for CpuMatrix
+ *
+ * \param in1_grad forward input grad 1, size: nSamples * dim.
+ * \param in2_grad forward input grad 2,
+ * size: n2 * dim (n2 == 1 or n2 == nSamples).
+ *
+ * \param out_grad backward loss output grad, size : nSamples * 1.
+ * \param out_val forward output value, size: nSamples * 1.
+ * \param in1_val forward input value 1, size: nSamples * dim.
+ * \param in2_val forward input value 2,
+ * size: n2 * dim (n2 == 1 or n2 == nSamples).
+ * \param scale, default 1.0
+ */
+template <>
+void CosSimBackward<DEVICE_TYPE_CPU>(const CpuMatrix& out_grad,
+ const CpuMatrix& out_val,
+ const CpuMatrix& in1_val,
+ const CpuMatrix& in2_val,
+ CpuMatrix& in1_grad,
+ CpuMatrix& in2_grad,
+ real scale) {
+ CHECK(out_grad.getData() && out_val.getData() && in1_val.getData() &&
+ in2_val.getData() && in1_grad.getData() && in2_grad.getData());
+ CHECK_EQ(out_val.useGpu_, false) << "Matrix type are GPU, CPU required";
+
+ const real* grad = out_grad.getData();
+ const real* out = out_val.getData();
+ const real* prev_out_x = in1_val.getData();
+ const real* prev_out_y = in2_val.getData();
+ real* prev_grad_x = in1_grad.getData();
+ real* prev_grad_y = in2_grad.getData();
+
+ size_t num_samples = out_grad.getHeight();
+ size_t dim = in1_val.getWidth();
+ CHECK_EQ(in2_val.getHeight(), in2_grad.getHeight());
+ CHECK(in2_val.getHeight() == 1LU || in2_val.getHeight() == num_samples);
+ size_t inc = (in2_val.getHeight() == 1LU) ? 0 : dim;
+ for (size_t i = 0; i < num_samples; ++i,
+ prev_out_x += dim,
+ prev_out_y += inc,
+ prev_grad_x += dim,
+ prev_grad_y += inc) {
+ real square_sum_x = 0;
+ real square_sum_y = 0;
+ real xy = 0;
+ for (size_t j = 0; j < dim; ++j) {
+ square_sum_x += prev_out_x[j] * prev_out_x[j];
+ square_sum_y += prev_out_y[j] * prev_out_y[j];
+ xy += prev_out_x[j] * prev_out_y[j];
+ }
+ CHECK(square_sum_x > 0 && square_sum_y > 0);
+ if (xy == 0) {
+ real reciprocal =
+ 1.0f / (std::sqrt(square_sum_x) * std::sqrt(square_sum_y));
+ for (size_t j = 0; j < dim; ++j) {
+ prev_grad_x[j] += scale * grad[i] * prev_out_y[j] * reciprocal;
+ prev_grad_y[j] += scale * grad[i] * prev_out_x[j] * reciprocal;
+ }
+ } else {
+ real reciprocal_xy = 1.0f / xy;
+ real reciprocal_square_sum_x = 1.0f / square_sum_x;
+ real reciprocal_square_sum_y = 1.0f / square_sum_y;
+ for (size_t j = 0; j < dim; ++j) {
+ prev_grad_x[j] +=
+ out[i] * grad[i] * (prev_out_y[j] * reciprocal_xy -
+ prev_out_x[j] * reciprocal_square_sum_x);
+ prev_grad_y[j] +=
+ out[i] * grad[i] * (prev_out_x[j] * reciprocal_xy -
+ prev_out_y[j] * reciprocal_square_sum_y);
+ }
+ }
+ }
+}
+
+/**
+ * Cosine Similarity backward Derivative
+ *
+ * \param outputs[0] forward input grad 1, size: nSamples * dim.
+ * \param outputs[1] forward input grad 2,
+ * size: n2 * dim (n2 == 1 or n2 == nSamples).
+ *
+ * \param inputs[0] backward loss output grad, size : nSamples * 1.
+ * \param inputs[1] forward output value, size: nSamples * 1.
+ * \param inputs[2] forward input value 1, size: nSamples * dim.
+ * \param inputs[3] forward input value 2,
+ * size: n2 * dim (n2 == 1 or n2 == nSamples).
+ */
+template <DeviceType Device>
+class CosSimBackwardFunc : public FunctionBase {
+ void init(const FuncConfig& config) override {
+    scale_ = config.get<real>("scale");
+ }
+
+ void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+ CHECK_EQ(inputs.size(), 4UL);
+ CHECK_EQ(outputs.size(), 2UL);
+ /// dim of out_grad and out_val == 1, column vector
+ CHECK_EQ(inputs[0].shape()[1], 1UL);
+ CHECK_EQ(inputs[1].shape()[1], 1UL);
+ /// nSamples of out_grad == out_val == in_val1 == in_grad1
+ CHECK_EQ(inputs[1].shape()[0], inputs[0].shape()[0]);
+    CHECK_EQ(inputs[2].shape()[0], inputs[0].shape()[0]);
+ CHECK_EQ(outputs[0].shape()[0], inputs[0].shape()[0]);
+ /// dim of in1_val1 == in_val2 == in_grad1 == in_grad2
+ CHECK_EQ(inputs[3].shape()[1], inputs[2].shape()[1]);
+ CHECK_EQ(outputs[0].shape()[1], inputs[2].shape()[1]);
+ CHECK_EQ(outputs[1].shape()[1], inputs[2].shape()[1]);
+
+ CHECK(inputs[0].data() && inputs[1].data() && inputs[2].data() &&
+ inputs[3].data() && outputs[0].data() && outputs[1].data());
+
+ CHECK_EQ(outputs[0].getArgType(), ADD_TO);
+ CHECK_EQ(outputs[1].getArgType(), ADD_TO);
+
+    const auto out_grad = inputs[0].matrix<Device>();
+    const auto out_val = inputs[1].matrix<Device>();
+    const auto in1_val = inputs[2].matrix<Device>();
+    const auto in2_val = inputs[3].matrix<Device>();
+    auto in1_grad = outputs[0].matrix<Device>();
+    auto in2_grad = outputs[1].matrix<Device>();
+
+    CosSimBackward<Device>(
+        out_grad, out_val, in1_val, in2_val, in1_grad, in2_grad, scale_);
+ }
+
+private:
+ real scale_;
+};
+
+REGISTER_TYPED_FUNC(CosSimForward, CPU, CosSimForwardFunc);
+REGISTER_TYPED_FUNC(CosSimBackward, CPU, CosSimBackwardFunc);
+#ifndef PADDLE_ONLY_CPU
+REGISTER_TYPED_FUNC(CosSimForward, GPU, CosSimForwardFunc);
+REGISTER_TYPED_FUNC(CosSimBackward, GPU, CosSimBackwardFunc);
+#endif
+} // namespace paddle
diff --git a/paddle/function/CosSimOp.h b/paddle/function/CosSimOp.h
new file mode 100644
index 0000000000000000000000000000000000000000..be73064e6375bf1e6c6a7ca6de52e9b9b755880b
--- /dev/null
+++ b/paddle/function/CosSimOp.h
@@ -0,0 +1,61 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "Function.h"
+
+namespace paddle {
+
+/**
+ * \brief Cosine Similarity Forward.
+ * for each row i,
+ * out[i] = scale * cos(in1[i], in2[i])
+ *          = scale * \sum_j (in1[i][j] * in2[i][j]) /
+ *            sqrt(\sum_j (in1[i][j]^2) * \sum_j (in2[i][j]^2))
+ *
+ * \param[out] output  output value.
+ * \param[in]  input1  input value.
+ * \param[in]  input2  input value.
+ * \param[in]  scale   default 1.0.
+ *
+ */
+template <DeviceType Device>
+void CosSimForward(typename Tensor<real, Device>::Matrix& output,
+                   const typename Tensor<real, Device>::Matrix& input1,
+                   const typename Tensor<real, Device>::Matrix& input2,
+                   real scale);
+
+/**
+ * \brief Cosine Similarity BackWard for Derivative.
+ *
+ * \param[in]     out_grad   backward loss output grad.
+ * \param[in]     out_value  forward output value.
+ * \param[in]     in1_value  forward input value 1.
+ * \param[in]     in2_value  forward input value 2.
+ * \param[in,out] in1_grad   forward input grad 1.
+ * \param[in,out] in2_grad   forward input grad 2.
+ * \param[in]     scale      default 1.0.
+ *
+ */
+template <DeviceType Device>
+void CosSimBackward(const typename Tensor<real, Device>::Matrix& out_grad,
+                    const typename Tensor<real, Device>::Matrix& out_value,
+                    const typename Tensor<real, Device>::Matrix& in1_value,
+                    const typename Tensor<real, Device>::Matrix& in2_value,
+                    typename Tensor<real, Device>::Matrix& in1_grad,
+                    typename Tensor<real, Device>::Matrix& in2_grad,
+                    real scale);
+
+} // namespace paddle
diff --git a/paddle/function/CosSimOpGpu.cu b/paddle/function/CosSimOpGpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..1dd733674fa0542c76070955ec63e008b083c7d2
--- /dev/null
+++ b/paddle/function/CosSimOpGpu.cu
@@ -0,0 +1,241 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "hl_base.h"
+#include "hl_device_functions.cuh"
+#include "CosSimOp.h"
+
+namespace paddle {
+
+template<int block_size>
+__global__ void KeCosSim(real* output,
+ const real* input1,
+ const real* input2,
+ int width,
+ int input1_height,
+ int input2_height,
+ real scale) {
+ const int ty = blockIdx.y;
+ int tid = threadIdx.x;
+
+ __shared__ real xx[block_size];
+ __shared__ real yy[block_size];
+ __shared__ real xy[block_size];
+
+ xx[tid] = 0.0;
+ yy[tid] = 0.0;
+ xy[tid] = 0.0;
+ __syncthreads();
+
+ input1 += ty * width;
+ if (input2_height > 1) {
+ input2 += ty * width;
+ }
+ for (int index = tid; index < width; index += block_size) {
+ real x = input1[index];
+ real y = input2[index];
+ xx[tid] += x * x;
+ yy[tid] += y * y;
+ xy[tid] += x * y;
+ }
+ __syncthreads();
+
+ for (int s = block_size / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ xx[tid] += xx[tid + s];
+ yy[tid] += yy[tid + s];
+ xy[tid] += xy[tid + s];
+ }
+ __syncthreads();
+ }
+ if (tid == 0) {
+ output[ty] = scale * xy[0] / (sqrt(xx[0]) * sqrt(yy[0]));
+ }
+}
+
+void hlCossim(real* output,
+ const real* input1,
+ const real* input2,
+ size_t width,
+ size_t input1_height,
+ size_t input2_height,
+ real scale) {
+ CHECK_NOTNULL(output);
+ CHECK_NOTNULL(input1);
+ CHECK_NOTNULL(input2);
+ const int block_size = 256;
+ dim3 threads(block_size, 1);
+ dim3 grid(1, input1_height);
+
+  KeCosSim<block_size><<<grid, threads, 0, STREAM_DEFAULT>>>
+ (output, input1, input2, width, input1_height, input2_height, scale);
+ CHECK_SYNC("hlCossim failed");
+}
+
+template <>
+void CosSimForward<DEVICE_TYPE_GPU>(GpuMatrix& out_mat,
+ const GpuMatrix& in1_mat,
+ const GpuMatrix& in2_mat,
+ real scale) {
+ CHECK(out_mat.getData() && in1_mat.getData() && in2_mat.getData());
+ CHECK(in1_mat.useGpu_ == true && in2_mat.useGpu_ == true)
+ << "Matrix type are not GPU";
+
+ size_t num_samples = out_mat.getHeight();
+ size_t dim = in1_mat.getWidth();
+ real* out = out_mat.getData();
+ const real* x = in1_mat.getData();
+ const real* y = in2_mat.getData();
+ hlCossim(out, x, y, dim, in1_mat.getHeight(), in2_mat.getHeight(), scale);
+}
+
+template<int block_size>
+__global__ void KeCosSimDerivative(const real* grad,
+ const real* output,
+ const real* prev_out_x,
+ const real* prev_out_y,
+ real* prev_grad_x,
+ real* prev_grad_y,
+ size_t width,
+ size_t input1_height,
+ size_t input2_height,
+ real scale) {
+ const int ty = blockIdx.y;
+ int tid = threadIdx.x;
+
+ __shared__ real xx[block_size];
+ __shared__ real yy[block_size];
+ __shared__ real xy[block_size];
+
+ xx[tid] = 0.0;
+ yy[tid] = 0.0;
+ xy[tid] = 0.0;
+ __syncthreads();
+
+ prev_out_x += ty * width;
+ prev_grad_x += ty * width;
+ if (input2_height > 1) {
+ prev_out_y += ty * width;
+ prev_grad_y += ty * width;
+ }
+ for (int index = tid; index < width; index += block_size) {
+ real x = prev_out_x[index];
+ real y = prev_out_y[index];
+ xx[tid] += x * x;
+ yy[tid] += y * y;
+ xy[tid] += x * y;
+ }
+ __syncthreads();
+
+ for (int s = block_size / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ xx[tid] += xx[tid + s];
+ yy[tid] += yy[tid + s];
+ xy[tid] += xy[tid + s];
+ }
+ __syncthreads();
+ }
+ if (xy[0] == 0) {
+ real reciprocal = 1.0 / (sqrt(xx[0]) * sqrt(yy[0]));
+ for (int index = tid; index < width; index += block_size) {
+ prev_grad_x[index] +=
+ scale * grad[ty] * prev_out_y[index] * reciprocal;
+ if (input2_height > 1) {
+ prev_grad_y[index] +=
+ scale * grad[ty] * prev_out_x[index] * reciprocal;
+ } else {
+ paddle::paddleAtomicAdd(prev_grad_y + index,
+ scale * grad[ty] * prev_out_x[index] * reciprocal);
+ }
+ }
+ } else {
+ real reciprocalXY = 1.0 / xy[0];
+ real reciprocalSquareSumX = 1.0 / xx[0];
+ real reciprocalSquareSumY = 1.0 / yy[0];
+ for (int index = tid; index < width; index += block_size) {
+ prev_grad_x[index] += output[ty] * grad[ty] *
+ (prev_out_y[index] * reciprocalXY -
+ prev_out_x[index] * reciprocalSquareSumX);
+ if (input2_height > 1) {
+ prev_grad_y[index] += output[ty] * grad[ty] *
+ (prev_out_x[index] * reciprocalXY -
+ prev_out_y[index] * reciprocalSquareSumY);
+ } else {
+ paddle::paddleAtomicAdd(prev_grad_y + index, output[ty] * grad[ty] *
+ (prev_out_x[index] * reciprocalXY -
+ prev_out_y[index] * reciprocalSquareSumY));
+ }
+ }
+ }
+}
+
+void hlCossimDerivative(const real* grad,
+ const real* output,
+ const real* prev_out_x,
+ const real* prev_out_y,
+ real* prev_grad_x,
+ real* prev_grad_y,
+ size_t width,
+ size_t input1_height,
+ size_t input2_height,
+ real scale) {
+ CHECK_NOTNULL(grad);
+ CHECK_NOTNULL(output);
+ CHECK_NOTNULL(prev_out_x);
+ CHECK_NOTNULL(prev_out_y);
+ CHECK_NOTNULL(prev_grad_x);
+ CHECK_NOTNULL(prev_grad_y);
+ const int block_size = 256;
+ dim3 threads(block_size, 1);
+ dim3 grid(1, input1_height);
+  KeCosSimDerivative<block_size><<<grid, threads, 0, STREAM_DEFAULT>>>
+ (grad, output, prev_out_x, prev_out_y, prev_grad_x, prev_grad_y, width,
+ input1_height, input2_height, scale);
+ CHECK_SYNC("hlCossimDerivate failed");
+}
+
+template <>
+void CosSimBackward<DEVICE_TYPE_GPU>(const GpuMatrix& out_grad,
+ const GpuMatrix& out_val,
+ const GpuMatrix& in1_val,
+ const GpuMatrix& in2_val,
+ GpuMatrix& in1_grad,
+ GpuMatrix& in2_grad,
+ real scale) {
+ CHECK(out_grad.getData() && out_val.getData() && in1_val.getData() &&
+ in2_val.getData() && in1_grad.getData() && in2_grad.getData());
+ CHECK(out_grad.useGpu_ && out_val.useGpu_ && in1_val.useGpu_
+ && in2_val.useGpu_ && in1_grad.useGpu_ && in2_grad.useGpu_)
+ << "Matrix types are not equally GPU";
+
+ size_t dim = in1_val.getWidth();
+ const real* grad = out_grad.getData();
+ const real* out = out_val.getData();
+ const real* prev_out_x = in1_val.getData();
+ const real* prev_out_y = in2_val.getData();
+ real* prev_grad_x = in1_grad.getData();
+ real* prev_grad_y = in2_grad.getData();
+ hlCossimDerivative(grad,
+ out,
+ prev_out_x,
+ prev_out_y,
+ prev_grad_x,
+ prev_grad_y,
+ dim,
+ in1_val.getHeight(),
+ in2_val.getHeight(),
+ scale);
+}
+
+} // namespace paddle
diff --git a/paddle/function/CosSimOpTest.cpp b/paddle/function/CosSimOpTest.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..48c815f027161b48c17ce654ab819156fd856199
--- /dev/null
+++ b/paddle/function/CosSimOpTest.cpp
@@ -0,0 +1,64 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include "FunctionTest.h"
+#include "paddle/math/Matrix.h"
+
+using namespace paddle; // NOLINT
+
+void testCosSimForward(size_t height_x,
+ size_t height_y,
+ size_t width,
+ real scale) {
+ FunctionCompare test("CosSimForward", FuncConfig().set("scale", scale));
+ // prepare input arguments
+ test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width}));
+ test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width}));
+ test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1}),
+ ASSIGN_TO);
+ // run Function
+ test.run();
+}
+
+void testCosSimBackward(size_t height_x,
+ size_t height_y,
+ size_t width,
+ real scale) {
+ FunctionCompare test("CosSimBackward", FuncConfig().set("scale", scale));
+ // prepare input arguments
+ test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1}));
+ test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1}));
+ test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width}));
+ test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width}));
+ test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width}),
+ ADD_TO);
+ test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width}),
+ ADD_TO);
+ // run Function
+ test.run();
+}
+
+TEST(Matrix, cosSim) {
+ for (auto height_x : {10, 100, 1000}) {
+ for (auto height_y : {1, height_x}) {
+ for (auto width : {10, 100, 1000}) {
+ for (auto scale : {1.0, 2.0}) {
+ testCosSimForward(height_x, height_y, width, scale);
+ testCosSimBackward(height_x, height_y, width, scale);
+ }
+ }
+ }
+ }
+}
diff --git a/paddle/function/FunctionTest.h b/paddle/function/FunctionTest.h
index 00f59f97d4c8c1076abe00866b786615a9801a5d..0cfafdb27f55a3e6617d31a968d2a05fc77f5b46 100644
--- a/paddle/function/FunctionTest.h
+++ b/paddle/function/FunctionTest.h
@@ -69,6 +69,54 @@ public:
gpuMemory_.back()->getBuf(), input.valueType(), input.shape()));
}
+ // assume one copy of sequence is shared by different SequenceArgs
+ void addSequence(const SequenceIdArg& input) {
+ CHECK_EQ(input.shape().ndims(), 1UL);
+ size_t batchSize = input.shape()[0];
+ size_t numSeqs = batchSize / 10 + 1;
+ size_t sizeId = (numSeqs + 1) * sizeOfValuType(VALUE_TYPE_INT32);
+    cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(sizeId));
+    gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(sizeId));
+    cpuSeq_ = std::make_shared<SequenceIdArg>(cpuMemory_.back()->getBuf(),
+                                              TensorShape{numSeqs + 1});
+    gpuSeq_ = std::make_shared<SequenceIdArg>(gpuMemory_.back()->getBuf(),
+                                              TensorShape{numSeqs + 1});
+ /// init sequence Id
+ initArg(*cpuSeq_, batchSize);
+
+ // todo(tianbing), delete it
+ CHECK_EQ(cpuSeq_->shape().getElements(), cpuSeq_->numSeqs() + 1);
+
+ CpuIVector cpuSeq(cpuSeq_->shape().getElements(), (int*)cpuSeq_->data());
+ GpuIVector gpuSeq(gpuSeq_->shape().getElements(), (int*)gpuSeq_->data());
+ gpuSeq.copyFrom(cpuSeq);
+ }
+
+ void addInputs(const SequenceArg& input) {
+ CHECK_EQ(input.shape().ndims(), 2UL);
+ size_t batchSize = input.shape()[0];
+ if (!cpuSeq_ || !gpuSeq_) { // sequence not exist
+ addSequence(SequenceIdArg(TensorShape{batchSize}));
+ }
+
+ size_t size =
+ input.shape().getElements() * sizeOfValuType(input.valueType());
+    cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size));
+    gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(size));
+
+ /// SequenceArg
+    cpuInputs_.emplace_back(
+        std::make_shared<SequenceArg>(cpuMemory_.back()->getBuf(),
+                                      input.valueType(),
+                                      input.shape(),
+                                      *cpuSeq_));
+    gpuInputs_.emplace_back(
+        std::make_shared<SequenceArg>(gpuMemory_.back()->getBuf(),
+                                      input.valueType(),
+                                      input.shape(),
+                                      *gpuSeq_));
+ }
+
// output need only contains shape, do not contains data.
void addOutputs(const BufferArg& output, ArgType argType = ASSIGN_TO) {
size_t size =
@@ -116,24 +164,31 @@ public:
        std::make_shared<SparseMatrixArg>(*gpuSparse_, argType));
}
- void addInputs(const SequenceArg& input) {
- size_t batchSize = input.shape()[0];
- size_t numSeqs = batchSize / 10 + 1;
-
- size_t sizeId = (numSeqs + 1) * sizeOfValuType(VALUE_TYPE_INT32);
-    cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(sizeId));
-    gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(sizeId));
-
- TensorShape seqsId({numSeqs + 1});
- // void* cpuBuffer = cpuMemory_.back()->getBuf();
- // void* gpuBuffer = gpuMemory_.back()->getBuf();
+ void addOutputs(const SequenceArg& output, ArgType argType = ASSIGN_TO) {
+ CHECK_EQ(output.shape().ndims(), 2UL);
+ size_t batchSize = output.shape()[0];
+ if (!cpuSeq_ || !gpuSeq_) { // sequence not exist
+ addSequence(SequenceIdArg(TensorShape{batchSize}));
+ }
size_t size =
- input.shape().getElements() * sizeOfValuType(input.valueType());
+ output.shape().getElements() * sizeOfValuType(output.valueType());
    cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size));
    gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(size));
- // TODO: need be implemented.
+ /// SequenceArg
+    cpuOutputs_.emplace_back(
+        std::make_shared<SequenceArg>(cpuMemory_.back()->getBuf(),
+                                      output.valueType(),
+                                      output.shape(),
+                                      *cpuSeq_,
+                                      argType));
+    gpuOutputs_.emplace_back(
+        std::make_shared<SequenceArg>(gpuMemory_.back()->getBuf(),
+                                      output.valueType(),
+                                      output.shape(),
+                                      *gpuSeq_,
+                                      argType));
}
void addInputs(const SparseMatrixArg& input) {
@@ -193,14 +248,44 @@ public:
  std::shared_ptr<FunctionBase> getGpuFunction() const { return gpuFunc_; }
protected:
+ // only init cpu argument, gpu argument copy from cpu argument.
+ void initArg(BufferArg& arg) {
+ CpuVector vector(arg.shape().getElements(), (real*)arg.data());
+ vector.uniform(0.001, 1);
+ }
+
+ void initArg(SequenceArg& arg) {
+ /// init only matrix
+ CpuVector vector(arg.shape().getElements(), (real*)arg.data());
+ vector.uniform(0.001, 1);
+ }
+
+ void initArg(SequenceIdArg& arg, size_t batchSize) {
+ size_t numSeqs = arg.numSeqs();
+    int* buf = reinterpret_cast<int*>(arg.data());
+ int pos = 0;
+ size_t maxLen = 2 * batchSize / numSeqs;
+ for (int i = 0; i < (int)numSeqs; ++i) {
+ int len = 1 + uniformRandom(std::min(
+ maxLen, batchSize - pos - numSeqs + i));
+ buf[i] = pos;
+ pos += len;
+ VLOG(1) << " len=" << len;
+ }
+ buf[numSeqs] = batchSize;
+ }
+
void initInputs() {
for (size_t i = 0; i < cpuInputs_.size(); i++) {
if (cpuInputs_[i]->isSparseArg()) {
continue; /// sparse matrix already init
}
- initArg(*cpuInputs_[i]);
-
+ if (cpuInputs_[i]->isSequenceArg()) {
+      initArg(dynamic_cast<SequenceArg&>(*cpuInputs_[i]));
+ } else {
+ initArg(*cpuInputs_[i]);
+ }
// TODO: Need a BufferCopy used to copy from one BufferArg to another.
CpuVector cpuVector(cpuInputs_[i]->shape().getElements(),
(real*)cpuInputs_[i]->data());
@@ -217,7 +302,11 @@ protected:
continue; /// sparse matrix already init
}
- initArg(*cpuOutputs_[i]);
+ if (cpuOutputs_[i]->isSequenceArg()) {
+      initArg(dynamic_cast<SequenceArg&>(*cpuOutputs_[i]));
+ } else {
+ initArg(*cpuOutputs_[i]);
+ }
// TODO: Need a BufferCopy used to copy from one BufferArg to another.
CpuVector cpuVector(cpuOutputs_[i]->shape().getElements(),
@@ -241,28 +330,6 @@ protected:
}
}
- // only init cpu argument, gpu argument copy from cpu argument.
- void initArg(BufferArg& arg) {
- CpuVector vector(arg.shape().getElements(), (real*)arg.data());
- vector.uniform(0.001, 1);
- }
-
- void initArg(SequenceIdArg& arg, size_t batchSize) {
- size_t numSeqs = arg.numSeqs();
- int* buf = reinterpret_cast(arg.data());
- int pos = 0;
- size_t maxLen = 2 * batchSize / numSeqs;
- for (int i = 0; i < (int)numSeqs; ++i) {
- int len = uniformRandom(
- std::min(maxLen, batchSize - pos - numSeqs + i)) +
- 1;
- buf[i] = pos;
- pos += len;
- VLOG(1) << " len=" << len;
- }
- buf[numSeqs] = batchSize;
- }
-
protected:
  std::shared_ptr<FunctionBase> cpuFunc_;
  std::shared_ptr<FunctionBase> gpuFunc_;
@@ -274,6 +341,8 @@ protected:
  std::vector<BufferArgPtr> gpuOutputs_;
  std::shared_ptr<SparseMatrixArg> cpuSparse_;
  std::shared_ptr<SparseMatrixArg> gpuSparse_;
+  std::shared_ptr<SequenceIdArg> cpuSeq_;
+  std::shared_ptr<SequenceIdArg> gpuSeq_;
};
} // namespace paddle
diff --git a/paddle/function/MulOpTest.cpp b/paddle/function/MulOpTest.cpp
index 158c3c90983b12c352765479006669c5c9e5a8aa..8748eb0d79fa0fcb0935eac5bb37b44274128aa0 100644
--- a/paddle/function/MulOpTest.cpp
+++ b/paddle/function/MulOpTest.cpp
@@ -60,7 +60,7 @@ TEST(MulOp, DDDMatrixMul) {
if (transa && transb) {
continue;
}
- VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ')
+ VLOG(3) << std::setiosflags(std::ios::left) << std::setfill(' ')
<< " transa=" << transa << " transb=" << transb
<< " dimM=" << std::setw(5) << dimM
<< " dimN=" << std::setw(5) << dimN
@@ -104,7 +104,7 @@ TEST(MuLOp, DSparseDMul) {
for (const auto dimK : {3, 10}) {
for (const auto nnz : {3, 10}) {
for (const auto FORMAT : {SPARSE_CSR}) {
- VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ')
+ VLOG(3) << std::setiosflags(std::ios::left) << std::setfill(' ')
<< " dimM=" << std::setw(5) << dimM
<< " dimN=" << std::setw(5) << dimN
<< " dimK=" << std::setw(5) << dimK
@@ -150,7 +150,7 @@ TEST(MulOp, DDSparseMul) {
for (const auto dimK : {3, 10}) {
for (const auto nnz : {3, 10}) {
for (const auto FORMAT : {SPARSE_CSR, SPARSE_CSC}) {
- VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ')
+ VLOG(3) << std::setiosflags(std::ios::left) << std::setfill(' ')
<< " dimM=" << std::setw(5) << dimM
<< " dimN=" << std::setw(5) << dimN
<< " dimK=" << std::setw(5) << dimK
@@ -197,7 +197,7 @@ TEST(MulOp, SparseDDMul) {
for (const auto dimK : {3, 10}) {
for (const auto nnz : {3, 10}) {
for (const auto FORMAT : {SPARSE_CSC, SPARSE_CSR}) {
- VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ')
+ VLOG(3) << std::setiosflags(std::ios::left) << std::setfill(' ')
<< " dimM=" << std::setw(5) << dimM
<< " dimN=" << std::setw(5) << dimN
<< " dimK=" << std::setw(5) << dimK
diff --git a/paddle/gserver/dataproviders/PyDataProvider2.cpp b/paddle/gserver/dataproviders/PyDataProvider2.cpp
index c26e242534f2afcff396762adb085bf99303e2b5..b8079dc0796d0e300e65ac6b6b8d3bc826b1e504 100644
--- a/paddle/gserver/dataproviders/PyDataProvider2.cpp
+++ b/paddle/gserver/dataproviders/PyDataProvider2.cpp
@@ -647,7 +647,7 @@ public:
DataBatch& gpuBatch = *batch;
std::vector<Argument>& gpuArguments = gpuBatch.getStreams();
gpuArguments.resize(cpuArguments.size());
- gpuBatch.setSize(size);
+ gpuBatch.setSize(bsize);
for (size_t i = 0; i < headers_.size(); ++i) {
gpuArguments[i].resizeAndCopyFrom(
cpuArguments[i], useGpu_, HPPL_STREAM_1);
diff --git a/paddle/gserver/layers/CosSimLayer.cpp b/paddle/gserver/layers/CosSimLayer.cpp
index 254120443dc3d41bf2422be2e88cb376d70c93d4..a6c0300acf6752a3536e7939577b561fd97d1eb8 100644
--- a/paddle/gserver/layers/CosSimLayer.cpp
+++ b/paddle/gserver/layers/CosSimLayer.cpp
@@ -26,15 +26,23 @@ bool CosSimLayer::init(const LayerMap& layerMap,
Layer::init(layerMap, parameterMap);
CHECK_EQ(inputLayers_.size(), 2LU);
+
+ createFunction(forward_,
+ "CosSimForward",
+ FuncConfig().set("scale", (real)config_.cos_scale()));
+ createFunction(backward_,
+ "CosSimBackward",
+ FuncConfig().set("scale", (real)config_.cos_scale()));
+
return true;
}
void CosSimLayer::forward(PassType passType) {
Layer::forward(passType);
-
/* malloc memory for the output_ if necessary */
int batchSize = getInputValue(0)->getHeight();
int size = getSize();
+ CHECK_EQ(forward_.size(), 1) << "Only one forward function needed";
{
REGISTER_TIMER_INFO("CosFwResetTimer", getName().c_str());
@@ -42,26 +50,43 @@ void CosSimLayer::forward(PassType passType) {
}
MatrixPtr outV = getOutputValue();
-
/* activation */ {
REGISTER_TIMER_INFO("CosFwAtvTimer", getName().c_str());
MatrixPtr prevOut1 = getInputValue(0);
MatrixPtr prevOut2 = getInputValue(1);
- outV->cosSim(*prevOut1, *prevOut2, config_.cos_scale());
+
+ CHECK(outV && prevOut1 && prevOut2);
+ BufferArgs inputs;
+ BufferArgs outputs;
+ inputs.addArg(*prevOut1);
+ inputs.addArg(*prevOut2);
+ outputs.addArg(*outV, ASSIGN_TO);
+ forward_[0]->calc(inputs, outputs);
}
}
void CosSimLayer::backward(const UpdateCallback& callback) {
/* activation */ {
REGISTER_TIMER_INFO("CosBpAtvTimer", getName().c_str());
- MatrixPtr outG = this->getOutputGrad();
-
- outG->cosSimDerivative(*this->getOutputValue(),
- *getInputValue(0),
- *getInputValue(1),
- *getInputGrad(0),
- *getInputGrad(1),
- config_.cos_scale());
+ CHECK_EQ(backward_.size(), 1) << "Only one backward function needed";
+
+ const auto outG = this->getOutputGrad();
+ const auto outV = this->getOutputValue();
+ const auto inV1 = this->getInputValue(0);
+ const auto inV2 = this->getInputValue(1);
+ auto inG1 = this->getInputGrad(0);
+ auto inG2 = this->getInputGrad(1);
+ CHECK(outG && outV && inV1 && inV2 && inG1 && inG2);
+ BufferArgs inputs;
+ BufferArgs outputs;
+ inputs.addArg(*outG);
+ inputs.addArg(*outV);
+ inputs.addArg(*inV1);
+ inputs.addArg(*inV2);
+ outputs.addArg(*inG1, ADD_TO);
+ outputs.addArg(*inG2, ADD_TO);
+
+ backward_[0]->calc(inputs, outputs);
}
}
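
With this change `CosSimLayer` no longer calls a virtual method on `Matrix`; it packs its operands into ordered `BufferArgs` lists and dispatches to a function object registered by name. The toy stand-in below (plain C++, not Paddle's real `FunctionBase`/`BufferArgs` classes) shows the per-row computation that `CosSimForward` performs:

```cpp
#include <cmath>
#include <iostream>
#include <vector>

using Matrix = std::vector<std::vector<float>>;  // row-major toy matrix

// Per-row cosine similarity: out[i] = scale * cos(a[i], b[i]).
// ASSIGN_TO semantics: the output is overwritten, not accumulated.
void cosSimForward(const Matrix& a, const Matrix& b,
                   std::vector<float>& out, float scale) {
  for (size_t i = 0; i < a.size(); ++i) {
    float xy = 0, xx = 0, yy = 0;
    for (size_t j = 0; j < a[i].size(); ++j) {
      xy += a[i][j] * b[i][j];
      xx += a[i][j] * a[i][j];
      yy += b[i][j] * b[i][j];
    }
    out[i] = scale * xy / (std::sqrt(xx) * std::sqrt(yy));
  }
}

int main() {
  Matrix a = {{1, 0}, {1, 1}};
  Matrix b = {{1, 0}, {1, 0}};
  std::vector<float> out(2);
  cosSimForward(a, b, out, /*scale=*/1.0f);
  std::cout << out[0] << " " << out[1] << "\n";  // 1 and ~0.707
}
```
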
diff --git a/paddle/gserver/layers/CosSimLayer.h b/paddle/gserver/layers/CosSimLayer.h
index 65549626098f084c5e1786885e06c1bdfa3ba74c..8afaee62c2dcacba006846df0111fcbe8f7575e4 100644
--- a/paddle/gserver/layers/CosSimLayer.h
+++ b/paddle/gserver/layers/CosSimLayer.h
@@ -28,7 +28,7 @@ namespace paddle {
*
* - Input1: A vector (batchSize * dataDim) *
* - Input2: A vector (batchSize * dataDim) or (1 * dataDim) *
- * - Output: A vector (dataDim * 1)
+ * - Output: A vector (batchSize * 1)
*
* The config file api is cos_sim.
*/
diff --git a/paddle/gserver/layers/CosSimVecMatLayer.cpp b/paddle/gserver/layers/CosSimVecMatLayer.cpp
index 5f652319e5620227fca166a8f72e5aed416bf5dd..aabafd473aa1e06a767d48d4c49b7b8662e992e7 100644
--- a/paddle/gserver/layers/CosSimVecMatLayer.cpp
+++ b/paddle/gserver/layers/CosSimVecMatLayer.cpp
@@ -18,7 +18,6 @@ limitations under the License. */
#include "paddle/utils/Stat.h"
namespace paddle {
-
/**
* @brief A layer for computing cosine similarity between a vector
* and each row of a matrix
@@ -98,11 +97,22 @@ bool CosSimVecMatLayer::init(const LayerMap& layerMap,
dataDim,
/* trans= */ false,
useGpu_);
+
+ CHECK(tmpRow0 && tmpRow1 && tmpRow2 && tmpRow3 && tmpMtx0 && tmpMtx1);
+
+ createFunction(forward_,
+ "CosSimForward",
+ FuncConfig().set("scale", (real)config_.cos_scale()));
+ createFunction(backward_,
+ "CosSimBackward",
+ FuncConfig().set("scale", (real)config_.cos_scale()));
+
return true;
}
void CosSimVecMatLayer::forward(PassType passType) {
Layer::forward(passType);
+ CHECK_EQ(forward_.size(), 1) << "Only one forward function needed";
MatrixPtr inV0 = getInputValue(0);
MatrixPtr inV1 = getInputValue(1);
@@ -118,17 +128,25 @@ void CosSimVecMatLayer::forward(PassType passType) {
}
MatrixPtr outV = getOutputValue();
-
+ CHECK(outV && inV0 && inV1);
REGISTER_TIMER_INFO("FwCosVMTimer", getName().c_str());
for (size_t i = 0; i < batchSize; i++) {
tmpRow0->setData(inV0->rowBuf(i));
tmpMtx0->setData(inV1->rowBuf(i));
tmpRow2->setData(outV->rowBuf(i));
- tmpRow2->cosSim(*(tmpMtx0), *(tmpRow0), config_.cos_scale());
+
+ BufferArgs inputs;
+ BufferArgs outputs;
+ inputs.addArg(*tmpMtx0);
+ inputs.addArg(*tmpRow0);
+ outputs.addArg(*tmpRow2, ASSIGN_TO);
+ forward_[0]->calc(inputs, outputs);
}
}
void CosSimVecMatLayer::backward(const UpdateCallback& callback) {
+ CHECK_EQ(backward_.size(), 1) << "Only one forward function needed";
+
MatrixPtr inV0 = getInputValue(0);
MatrixPtr inV1 = getInputValue(1);
MatrixPtr inG0 = getInputGrad(0);
@@ -137,27 +155,27 @@ void CosSimVecMatLayer::backward(const UpdateCallback& callback) {
MatrixPtr outG = getOutputGrad();
size_t batchSize = inV0->getHeight();
-
+ CHECK(inV0 && inV1 && inG0 && inG1 && outV && outG);
REGISTER_TIMER_INFO("BwCosVMTimer", getName().c_str());
- if (inG0 && inG1) {
- for (size_t i = 0; i < batchSize; i++) {
- tmpRow0->setData(inV0->rowBuf(i));
- tmpRow1->setData(inG0->rowBuf(i));
- tmpMtx0->setData(inV1->rowBuf(i));
- tmpMtx1->setData(inG1->rowBuf(i));
- tmpRow2->setData(outV->rowBuf(i));
- tmpRow3->setData(outG->rowBuf(i));
-
- tmpRow3->cosSimDerivative(*(tmpRow2),
- *(tmpMtx0),
- *(tmpRow0),
- *(tmpMtx1),
- *(tmpRow1),
- config_.cos_scale());
- }
- } else {
- CHECK(!inG0 || !inG1) << "Not supported";
+ for (size_t i = 0; i < batchSize; i++) {
+ tmpRow0->setData(inV0->rowBuf(i));
+ tmpRow1->setData(inG0->rowBuf(i));
+ tmpMtx0->setData(inV1->rowBuf(i));
+ tmpMtx1->setData(inG1->rowBuf(i));
+ tmpRow2->setData(outV->rowBuf(i));
+ tmpRow3->setData(outG->rowBuf(i));
+
+ BufferArgs inputs;
+ BufferArgs outputs;
+ inputs.addArg(*tmpRow3);
+ inputs.addArg(*tmpRow2);
+ inputs.addArg(*tmpMtx0);
+ inputs.addArg(*tmpRow0);
+ outputs.addArg(*tmpMtx1, ADD_TO);
+ outputs.addArg(*tmpRow1, ADD_TO);
+
+ backward_[0]->calc(inputs, outputs);
}
}
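
Note the asymmetry between the two paths: the forward output is tagged `ASSIGN_TO` while both backward outputs are tagged `ADD_TO`, because an input gradient may already hold contributions from other consumers of the same layer and must be accumulated into rather than overwritten. A minimal sketch of the two semantics (toy code, not Paddle's argument classes):

```cpp
#include <cassert>
#include <vector>

enum ArgType { ASSIGN_TO, ADD_TO };

// Write src into dst according to the argument type: overwrite for
// ASSIGN_TO (forward outputs), accumulate for ADD_TO (gradients).
void store(std::vector<float>& dst, const std::vector<float>& src,
           ArgType type) {
  assert(dst.size() == src.size());
  for (size_t i = 0; i < dst.size(); ++i) {
    dst[i] = (type == ADD_TO) ? dst[i] + src[i] : src[i];
  }
}

int main() {
  std::vector<float> grad = {1, 1};
  const std::vector<float> delta = {2, 3};
  store(grad, delta, ADD_TO);     // grad is now {3, 4}
  store(grad, delta, ASSIGN_TO);  // grad is now {2, 3}
}
```
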
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index a8b53e2105b053399e62fba5321fd22c1fe4a50d..1964b2f8bfaebc49fe3073e03c949a8a9c3e385a 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -941,59 +941,6 @@ void GpuMatrix::softreluDerivative(Matrix& output) {
void GpuMatrix::scaledTanh(Matrix& output, real p1, real p2) {
BaseMatrix::scaledTanh(output, p1, p2);
}
-void GpuMatrix::cosSim(Matrix& output1, Matrix& output2, real scale) {
- CHECK(output1.useGpu_ == true && output2.useGpu_ == true)
- << "Matrix type are not equal";
- size_t numSamples = getHeight();
- size_t dim = output1.getWidth();
- CHECK_EQ(getWidth(), 1UL);
- CHECK_EQ(output1.getHeight(), numSamples);
- CHECK_EQ(output1.getWidth(), output2.getWidth());
- real* out = getData();
- real* x = output1.getData();
- real* y = output2.getData();
- hl_cossim(out, x, y, dim, output1.getHeight(), output2.getHeight(), scale);
-}
-void GpuMatrix::cosSimDerivative(Matrix& output,
- Matrix& prevOut1,
- Matrix& prevOut2,
- Matrix& prevGrad1,
- Matrix& prevGrad2,
- real scale) {
- CHECK(output.useGpu_ == true && prevOut1.useGpu_ == true &&
- prevOut2.useGpu_ == true && prevGrad1.useGpu_ == true &&
- prevGrad2.useGpu_ == true)
- << "Matrix type are not equal";
- CHECK_EQ(getWidth(), 1UL);
- CHECK_EQ(output.getWidth(), 1UL);
-
- size_t numSamples = getHeight();
- CHECK_EQ(output.getHeight(), numSamples);
- CHECK_EQ(prevOut1.getHeight(), numSamples);
- CHECK_EQ(prevGrad1.getHeight(), numSamples);
-
- size_t dim = prevOut1.getWidth();
- CHECK_EQ(prevOut2.getWidth(), dim);
- CHECK_EQ(prevGrad1.getWidth(), dim);
- CHECK_EQ(prevGrad2.getWidth(), dim);
-
- real* grad = getData();
- real* out = output.getData();
- real* prevOutX = prevOut1.getData();
- real* prevOutY = prevOut2.getData();
- real* prevGradX = prevGrad1.getData();
- real* prevGradY = prevGrad2.getData();
- hl_cossim_derivative(grad,
- out,
- prevOutX,
- prevOutY,
- prevGradX,
- prevGradY,
- dim,
- prevOut1.getHeight(),
- prevOut2.getHeight(),
- scale);
-}
void GpuMatrix::randomizeUniform() {
CHECK(isContiguous());
@@ -3470,105 +3417,6 @@ void CpuMatrix::softmaxDerivative(Matrix& output, Matrix& sftmaxSum) {
}
}
-void CpuMatrix::cosSim(Matrix& output1, Matrix& output2, real scale) {
- size_t numSamples = getHeight();
- size_t dim = output1.getWidth();
- CHECK_EQ(getWidth(), 1UL);
- CHECK_EQ(output1.getHeight(), numSamples);
- CHECK_EQ(output1.getWidth(), output2.getWidth());
-
- real* out = getData();
- const real* x = output1.getData();
- const real* y = output2.getData();
- size_t yInc = dim;
- if (output2.getHeight() == 1LU) {
- yInc = 0;
- } else {
- CHECK_EQ(output2.getHeight(), numSamples);
- }
- for (size_t i = 0; i < numSamples; ++i, x += dim, y += yInc) {
- real squareSumX = 0;
- real squareSumY = 0;
- real xy = 0;
- for (size_t j = 0; j < dim; ++j) {
- squareSumX += _square(x[j]);
- squareSumY += _square(y[j]);
- xy += x[j] * y[j];
- }
- CHECK(squareSumX > 0 && squareSumY > 0);
- out[i] = scale * xy / (std::sqrt(squareSumX) * std::sqrt(squareSumY));
- }
-}
-
-void CpuMatrix::cosSimDerivative(Matrix& output,
- Matrix& prevOut1,
- Matrix& prevOut2,
- Matrix& prevGrad1,
- Matrix& prevGrad2,
- real scale) {
- CHECK(output.useGpu_ == false) << "Matrix type are not equal";
-
- CHECK_EQ(getWidth(), 1UL);
- CHECK_EQ(output.getWidth(), 1UL);
-
- size_t numSamples = getHeight();
- CHECK_EQ(output.getHeight(), numSamples);
- CHECK_EQ(prevOut1.getHeight(), numSamples);
- CHECK_EQ(prevGrad1.getHeight(), numSamples);
-
- size_t dim = prevOut1.getWidth();
- CHECK_EQ(prevOut2.getWidth(), dim);
- CHECK_EQ(prevGrad1.getWidth(), dim);
- CHECK_EQ(prevGrad2.getWidth(), dim);
-
- const real* grad = getData();
- const real* out = output.getData();
- const real* prevOutX = prevOut1.getData();
- const real* prevOutY = prevOut2.getData();
- real* prevGradX = prevGrad1.getData();
- real* prevGradY = prevGrad2.getData();
- size_t yInc = dim;
- if (prevOut2.getHeight() == 1LU) {
- yInc = 0;
- CHECK_EQ(prevGrad2.getHeight(), 1LU);
- } else {
- CHECK_EQ(prevOut2.getHeight(), numSamples);
- CHECK_EQ(prevGrad2.getHeight(), numSamples);
- }
- for (size_t i = 0; i < numSamples; ++i,
- prevOutX += dim,
- prevOutY += yInc,
- prevGradX += dim,
- prevGradY += yInc) {
- real squareSumX = 0;
- real squareSumY = 0;
- real xy = 0;
- for (size_t j = 0; j < dim; ++j) {
- squareSumX += _square(prevOutX[j]);
- squareSumY += _square(prevOutY[j]);
- xy += prevOutX[j] * prevOutY[j];
- }
- CHECK(squareSumX > 0 && squareSumY > 0);
- if (xy == 0) {
- real reciprocal = 1.0f / (std::sqrt(squareSumX) * std::sqrt(squareSumY));
- for (size_t j = 0; j < dim; ++j) {
- prevGradX[j] += scale * grad[i] * prevOutY[j] * reciprocal;
- prevGradY[j] += scale * grad[i] * prevOutX[j] * reciprocal;
- }
- } else {
- real reciprocalXY = 1.0f / xy;
- real reciprocalSquareSumX = 1.0f / squareSumX;
- real reciprocalSquareSumY = 1.0f / squareSumY;
- for (size_t j = 0; j < dim; ++j) {
- prevGradX[j] += out[i] * grad[i] * (prevOutY[j] * reciprocalXY -
- prevOutX[j] * reciprocalSquareSumX);
- prevGradY[j] += out[i] * grad[i] * (prevOutX[j] * reciprocalXY -
- prevOutY[j] * reciprocalSquareSumY);
- }
- }
- }
-}
-
void CpuMatrix::sumOfSquares(Matrix& output, Matrix& label) {
CHECK(output.useGpu_ == false && label.useGpu_ == false)
<< "Matrix type are not equal";
diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h
index c92c0a272d5a72868bd61035d77aa4ed0fad7a7c..ea4bbb86b057b526c5ea294b2cd835aef65de58d 100644
--- a/paddle/math/Matrix.h
+++ b/paddle/math/Matrix.h
@@ -799,26 +799,6 @@ public:
LOG(FATAL) << "Not implemented";
}
- /**
- * cosine similarity, for each row i,
- * this[i] = cos(output1[i], output2[i])
- *
- * output2 can only have one row, then for each row i,
- * this[i] = cos(output1[i], output2[0])
- */
- virtual void cosSim(Matrix& output1, Matrix& output2, real scale = 1.0f) {
- LOG(FATAL) << "Not implemented";
- }
-
- virtual void cosSimDerivative(Matrix& output,
- Matrix& prevOut1,
- Matrix& prevOut2,
- Matrix& prevGrad1,
- Matrix& prevGrad2,
- real scale = 1.0f) {
- LOG(FATAL) << "Not implemented";
- }
-
/// print out the values of elements to os
virtual void print(std::ostream& os) const {
LOG(FATAL) << "Not implemented";
@@ -1324,14 +1304,6 @@ public:
void softreluDerivative(Matrix& output);
void scaledTanh(Matrix& output, real p1, real p2);
- void cosSim(Matrix& output1, Matrix& output2, real scale);
- void cosSimDerivative(Matrix& output,
- Matrix& prevOut1,
- Matrix& prevOut2,
- Matrix& prevGrad1,
- Matrix& prevGrad2,
- real scale);
-
virtual void print(std::ostream& os) const;
virtual void print(std::ostream& os, size_t height, size_t width) const;
@@ -1752,14 +1724,6 @@ public:
void softreluDerivative(Matrix& output);
void scaledTanh(Matrix& output, real p1, real p2);
- void cosSim(Matrix& output1, Matrix& output2, real scale);
- void cosSimDerivative(Matrix& output,
- Matrix& prevOut1,
- Matrix& prevOut2,
- Matrix& prevGrad1,
- Matrix& prevGrad2,
- real scale);
-
void print(std::ostream& os) const;
void print(std::ostream& os, size_t height, size_t width) const;
void printOneRow(std::ostream& os, size_t idx) const;
diff --git a/paddle/math/tests/test_Matrix.cpp b/paddle/math/tests/test_Matrix.cpp
index a4084bdf7c6953651bfd9714fd8a5c930f774fe6..1c21da5b76e95603258a5006d0c57b00126e65b9 100644
--- a/paddle/math/tests/test_Matrix.cpp
+++ b/paddle/math/tests/test_Matrix.cpp
@@ -181,28 +181,6 @@ TEST(Matrix, copyByRowIndex) {
}
}
-void testCosSim(int heightX, int heightY, int width, real scale) {
- AutoCompare test(heightX, 1);
- CpuMatrix arg1(heightX, width);
- CpuMatrix arg2(heightY, width);
- arg1.randomizeUniform();
- arg2.randomizeUniform();
- arg2.add(-0.5);
- test.cmpWithArg(&Matrix::cosSim, arg1, arg2, scale);
-}
-
-TEST(Matrix, cosSim) {
- for (auto heightX : {10, 100, 1000}) {
- for (auto heightY : {1, heightX}) {
- for (auto width : {10, 100, 1000}) {
- for (auto scale : {1.0, 2.0}) {
- testCosSim(heightX, heightY, width, scale);
- }
- }
- }
- }
-}
-
void testParamReluForward(int height, int width, int w_height, int w_width) {
AutoCompare test(height, width);
CpuMatrix arg1(height, width);
diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp
index e024f2cf1b913f56301ac7b3380f0c382818f413..6caaea443c1df756bfeb775154e8a90400cc3211 100644
--- a/paddle/math/tests/test_matrixCompare.cpp
+++ b/paddle/math/tests/test_matrixCompare.cpp
@@ -720,61 +720,6 @@ TEST(Matrix, sequenceAvgForward) {
}
}
-void testCosSimDerivate(int heightX, int heightY, int width, real scale) {
- MatrixPtr prevOutX = CpuMatrix::create(heightX, width, false, false);
- MatrixPtr prevOutY = CpuMatrix::create(heightY, width, false, false);
- MatrixPtr grad = CpuMatrix::create(heightX, 1, false, false);
- MatrixPtr output = CpuMatrix::create(heightX, 1, false, false);
- MatrixPtr prevGradX = CpuMatrix::create(heightX, width, false, false);
- MatrixPtr prevGradY = CpuMatrix::create(heightY, width, false, false);
-
- prevOutX->randomizeUniform();
- prevOutY->randomizeUniform();
- grad->randomizeUniform();
- output->randomizeUniform();
- prevGradX->randomizeUniform();
- prevGradY->randomizeUniform();
-
- MatrixPtr prevOutXGpu = GpuMatrix::create(heightX, width, false, true);
- MatrixPtr prevOutYGpu = GpuMatrix::create(heightY, width, false, true);
- MatrixPtr gradGpu = GpuMatrix::create(heightX, 1, false, true);
- MatrixPtr outputGpu = GpuMatrix::create(heightX, 1, false, true);
- MatrixPtr prevGradXGpu = GpuMatrix::create(heightX, width, false, true);
- MatrixPtr prevGradYGpu = GpuMatrix::create(heightY, width, false, true);
-
- prevOutXGpu->copyFrom(*prevOutX);
- prevOutYGpu->copyFrom(*prevOutY);
- gradGpu->copyFrom(*grad);
- outputGpu->copyFrom(*output);
- prevGradXGpu->copyFrom(*prevGradX);
- prevGradYGpu->copyFrom(*prevGradY);
-
- grad->cosSimDerivative(
- *output, *prevOutX, *prevOutY, *prevGradX, *prevGradY, scale);
-
- gradGpu->cosSimDerivative(*outputGpu,
- *prevOutXGpu,
- *prevOutYGpu,
- *prevGradXGpu,
- *prevGradYGpu,
- scale);
-
- TensorCheckErr(*prevGradX, *prevGradXGpu);
- TensorCheckErr(*prevGradY, *prevGradYGpu);
-}
-
-TEST(Matrix, cosSimDerivate) {
- for (auto heightX : {1, 10, 100}) {
- for (auto heightY : {1, heightX}) {
- for (auto width : {1, 10, 100}) {
- for (auto scale : {1.0, 2.0}) {
- testCosSimDerivate(heightX, heightY, width, scale);
- }
- }
- }
- }
-}
-
void testParamReluBackwardDiff(int height,
int width,
int w_height,
diff --git a/paddle/utils/Util.cpp b/paddle/utils/Util.cpp
index 220aac1ff11e0ff263df8459f539237944b94c81..dbab4ec43ca2fa691445131d2cb14f51721a2e4c 100644
--- a/paddle/utils/Util.cpp
+++ b/paddle/utils/Util.cpp
@@ -289,6 +289,7 @@ void mkDir(const char* filename) {
void mkDirRecursively(const char* dir) {
struct stat sb;
+ if (*dir == 0) return; // empty string
if (!stat(dir, &sb)) return;
mkDirRecursively(path::dirname(dir).c_str());
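
The added guard is the recursion's base case: judging by its comment, Paddle's `path::dirname` maps a single-component relative path to the empty string (unlike POSIX `dirname`, which returns `.`), and `stat("")` always fails, so without the check the function would recurse forever. A standalone sketch of the same shape, using a simple string split in place of `path::dirname`:

```cpp
#include <string>
#include <sys/stat.h>
#include <sys/types.h>

// Stand-in for Paddle's path::dirname: "a/b/c" -> "a/b", "a" -> "".
static std::string dirnameOf(const std::string& p) {
  const size_t slash = p.rfind('/');
  return slash == std::string::npos ? "" : p.substr(0, slash);
}

void mkDirRecursively(const char* dir) {
  struct stat sb;
  if (*dir == 0) return;        // base case: empty string, stop recursing
  if (!stat(dir, &sb)) return;  // path already exists
  mkDirRecursively(dirnameOf(dir).c_str());  // create parents first
  mkdir(dir, 0755);
}

int main() { mkDirRecursively("a/b/c"); }
```
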
diff --git a/python/paddle/trainer_config_helpers/data_sources.py b/python/paddle/trainer_config_helpers/data_sources.py
index 0ea8fc77eef9f5daeaa262ce7808db3e980f991c..ab9a2562dcccb394c0b24741ceeb10061e40cb9a 100644
--- a/python/paddle/trainer_config_helpers/data_sources.py
+++ b/python/paddle/trainer_config_helpers/data_sources.py
@@ -201,7 +201,7 @@ def define_py_data_sources2(train_list, test_list, module, obj, args=None):
data.load_data_module = load_data_module
data.load_data_object = load_data_object
data.load_data_args = load_data_args
- data.async_load_data = True
+ data.async_load_data = False
return data
define_py_data_sources(
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 66fa58ac91e33bfeac37d1bfbdad8dab4789c4bd..1fdc4c462363712e8b5b4dee10d0aaa26f4deffa 100755
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -3677,26 +3677,27 @@ def pad_layer(input,
For example,
- .. code-block::
-
- input(2,2,2,3) = [
- [ [[1,2,3], [3,4,5]],
- [[2,3,5], [1,6,7]] ],
- [ [[4,3,1], [1,8,7]],
- [[3,8,9], [2,3,5]] ]
- ]
-
- pad_c=[1,1], pad_h=[0,0], pad_w=[0,0]
- output(2,4,2,3) = [
- [ [[0,0,0], [0,0,0]],
- [[1,2,3], [3,4,5]],
- [[2,3,5], [1,6,7]],
- [[0,0,0], [0,0,0]] ],
- [ [[0,0,0], [0,0,0]],
- [[4,3,1], [1,8,7]],
- [[3,8,9], [2,3,5]],
- [[0,0,0], [0,0,0]] ]
- ]
+ .. code-block:: python
+
+ input(2,2,2,3) = [
+ [ [[1,2,3], [3,4,5]],
+ [[2,3,5], [1,6,7]] ],
+ [ [[4,3,1], [1,8,7]],
+ [[3,8,9], [2,3,5]] ]
+ ]
+
+ pad_c=[1,1], pad_h=[0,0], pad_w=[0,0]
+
+ output(2,4,2,3) = [
+ [ [[0,0,0], [0,0,0]],
+ [[1,2,3], [3,4,5]],
+ [[2,3,5], [1,6,7]],
+ [[0,0,0], [0,0,0]] ],
+ [ [[0,0,0], [0,0,0]],
+ [[4,3,1], [1,8,7]],
+ [[3,8,9], [2,3,5]],
+ [[0,0,0], [0,0,0]] ]
+ ]
The simple usage is:
@@ -4191,13 +4192,7 @@ def block_expand_layer(input,
@wrap_name_default()
@layer_support()
-def maxout_layer(input,
- groups,
- num_channels=None,
- size_x=None,
- size_y=None,
- name=None,
- layer_attr=None):
+def maxout_layer(input, groups, num_channels=None, name=None, layer_attr=None):
"""
A layer to do max out on conv layer output.
- Input: output of a conv layer.
@@ -4227,12 +4222,6 @@ def maxout_layer(input,
:type num_channels: int|None
:param groups: The group number of input layer.
:type groups: int
- :param size_x: conv output width. If None will be set
- automatically from previous output.
- :type size_x: int|None
- :param size_y: conv output height. If None will be set
- automatically from previous output.
- :type size_y: int|None
:param name: The name of this layer, which can not be specified.
:type name: None|basestring.
:param layer_attr: Extra Layer attribute.
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_split_datasource.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_split_datasource.protostr
index 1cfb92255aa92fa3fbc16a816851a5c2f81c2b56..569b0b945a762e8b596e197adc06df64e33311af 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_split_datasource.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_split_datasource.protostr
@@ -19,7 +19,7 @@ model_config {
data_config {
type: "py2"
files: "train.list"
- async_load_data: true
+ async_load_data: false
for_test: false
load_data_module: "a"
load_data_object: "c"
@@ -58,7 +58,7 @@ opt_config {
test_data_config {
type: "py2"
files: "test.list"
- async_load_data: true
+ async_load_data: false
for_test: true
load_data_module: "b"
load_data_object: "d"