Commit 236af566 authored by Yibing Liu

separate index tensor from candidate tensors in multiplex_op

@@ -390,4 +390,125 @@ The content of a model parameter file saved by PaddlePaddle consists of a 16-byte header followed by the network parameters
* If the earliest error reported is a network communication problem, it is most likely a port conflict caused by running jobs in non-exclusive mode. Contact the OP to check whether the current MPI cluster supports submitting with the resource=full parameter; if it does, add this parameter when submitting and change the job port.
* If the current MPI cluster does not support exclusive job mode, contact the OP about switching to another cluster or upgrading the current one.
\ No newline at end of file
19. How to output multiple layers in PaddlePaddle
--------------------------------------------------

* Pass the layers to be output as the :code:`output_layer` argument of the :code:`paddle.inference.Inference()` interface:

  .. code-block:: python

    inferer = paddle.inference.Inference(output_layer=[layer1, layer2], parameters=parameters)

* Specify which fields to output. Taking the :code:`value` field as an example:

  .. code-block:: python

    out = inferer.infer(input=data_batch, flatten_result=False, field=["value"])

With :code:`flatten_result=False`, the returned result is a :code:`list` whose number of elements equals the number of output fields; each element of that :code:`list` is itself a :code:`list` made up of the corresponding field results of all output layers, and each field result is a :code:`numpy.array`. The default value of :code:`flatten_result` is :code:`True`; in that case PaddlePaddle concatenates, field by field, the results of all output layers row-wise, and if the corresponding dimensions of the :code:`numpy.array` results of the different output layers do not match for a field, the program cannot run correctly.
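As an illustration (the layer names come from the snippet above; the shapes printed depend on your network), the nested result returned with :code:`flatten_result=False` can be unpacked like this:

.. code-block:: python

    # out has one entry per requested field; only "value" was requested here.
    value_results = out[0]           # list with one numpy.array per output layer
    layer1_value = value_results[0]  # "value" result of layer1 for this batch
    layer2_value = value_results[1]  # "value" result of layer2 for this batch
    print(layer1_value.shape, layer2_value.shape)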
20. How to use the :code:`name` argument of :code:`paddle.layer.memory`
-------------------------------------------------------------------------

* :code:`paddle.layer.memory` is used to fetch the previous time step's output of a particular layer, which is specified by the :code:`name` argument; that is, :code:`paddle.layer.memory` is associated with the layer whose :code:`name` has the same value, and it takes that layer's output at the previous time step as its own output at the current time step.
* Every layer in PaddlePaddle has a unique name, which the user sets via the :code:`name` argument; when the user does not set it explicitly, PaddlePaddle assigns one automatically. :code:`paddle.layer.memory` is not a real layer; its own name is set via the :code:`memory_name` argument and is likewise assigned automatically when not set explicitly. The :code:`name` argument of :code:`paddle.layer.memory` specifies the layer it is to be associated with and must be set explicitly by the user; see the sketch below.
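A minimal sketch of this association (the layer name "rnn_state", the size 128, and the surrounding step function are assumptions for illustration, not part of the original answer):

.. code-block:: python

    def step(x):
        # fetch the previous time step's output of the layer named "rnn_state"
        mem = paddle.layer.memory(name="rnn_state", size=128)
        # this fc layer is named "rnn_state", so the memory above is bound to it
        return paddle.layer.fc(input=[x, mem],
                               size=128,
                               act=paddle.activation.Tanh(),
                               name="rnn_state")

    rnn = paddle.layer.recurrent_group(step=step, input=seq)  # seq: a sequence input layer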
21. Using dropout
-----------------

* There are two ways to use dropout in PaddlePaddle.

  * Set :code:`drop_rate` in the :code:`layer_attr` of the corresponding layer. Taking :code:`paddle.layer.fc` as an example:

    .. code-block:: python

      fc = paddle.layer.fc(input=input, layer_attr=paddle.attr.ExtraLayerAttribute(drop_rate=0.5))

  * Use :code:`paddle.layer.dropout`. Taking :code:`paddle.layer.fc` as an example:

    .. code-block:: python

      fc = paddle.layer.fc(input=input)
      drop_fc = paddle.layer.dropout(input=fc, dropout_rate=0.5)

* :code:`paddle.layer.dropout` actually uses :code:`paddle.layer.add_to` internally and applies dropout by setting :code:`drop_rate` on that layer in the first way described above; this approach consumes more memory.
* PaddlePaddle implements dropout in the activation function rather than in the layer.
* :code:`paddle.layer.lstmemory`, :code:`paddle.layer.grumemory` and :code:`paddle.layer.recurrent` do not apply the activation to their output in the usual way, so dropout cannot be enabled for these layers by setting :code:`drop_rate` as in the first approach. To use dropout with these layers, use the second approach, i.e. :code:`paddle.layer.dropout`, as in the sketch below.
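A minimal sketch of the second approach applied to an LSTM layer (:code:`lstm_input` is an assumed, suitably sized projection layer, not part of the original answer):

.. code-block:: python

    lstm = paddle.layer.lstmemory(input=lstm_input)
    # layer_attr(drop_rate=...) is not supported on lstmemory, so wrap its output instead
    drop_lstm = paddle.layer.dropout(input=lstm, dropout_rate=0.5)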
22. How to set up learning rate annealing
------------------------------------------

Set :code:`learning_rate_schedule` and the related arguments in the chosen optimizer. Taking the Adam algorithm as an example:

.. code-block:: python

    optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3,
        learning_rate_decay_a=0.5,
        learning_rate_decay_b=0.75,
        learning_rate_schedule="poly",)

PaddlePaddle currently supports 8 kinds of learning_rate_schedule, listed below together with the corresponding learning rate formulas; a small sketch of two of them follows the list.

* "constant"

  lr = learning_rate

* "poly"

  lr = learning_rate * pow(1 + learning_rate_decay_a * num_samples_processed, -learning_rate_decay_b)

  where num_samples_processed is the number of samples already trained on (the same below).

* "caffe_poly"

  lr = learning_rate * pow(1.0 - num_samples_processed / learning_rate_decay_a, learning_rate_decay_b)

* "exp"

  lr = learning_rate * pow(learning_rate_decay_a, num_samples_processed / learning_rate_decay_b)

* "discexp"

  lr = learning_rate * pow(learning_rate_decay_a, floor(num_samples_processed / learning_rate_decay_b))

* "linear"

  lr = max(learning_rate - learning_rate_decay_a * num_samples_processed, learning_rate_decay_b)

* "manual"

  An annealing method whose learning rate changes piecewise with the number of samples already trained on. With this learning_rate_schedule the user sets the piecewise decay factors via the :code:`learning_rate_args` argument, and the current learning rate is the product of the configured :code:`learning_rate` and the current decay factor. Taking the Adam algorithm as an example:

  .. code-block:: python

      optimizer = paddle.optimizer.Adam(
          learning_rate=1e-3,
          learning_rate_schedule="manual",
          learning_rate_args="1000:1.0,2000:0.9,3000:0.8",)

  In this example, the learning rate is :code:`1e-3 * 1.0` while at most 1000 samples have been trained on, :code:`1e-3 * 0.9` while more than 1000 but at most 2000 have been trained on, and :code:`1e-3 * 0.8` once more than 2000 have been trained on.

* "pass_manual"

  An annealing method whose learning rate changes piecewise with the number of passes already trained. With this learning_rate_schedule the user sets the piecewise decay factors via the :code:`learning_rate_args` argument, and the current learning rate is the product of the configured :code:`learning_rate` and the current decay factor. Taking the Adam algorithm as an example:

  .. code-block:: python

      optimizer = paddle.optimizer.Adam(
          learning_rate=1e-3,
          learning_rate_schedule="pass_manual",
          learning_rate_args="1:1.0,2:0.9,3:0.8",)

  In this example, the learning rate is :code:`1e-3 * 1.0` while at most 1 pass has been trained, :code:`1e-3 * 0.9` while more than 1 but at most 2 passes have been trained, and :code:`1e-3 * 0.8` once more than 2 passes have been trained.
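For reference, a minimal plain-Python sketch (not part of the PaddlePaddle API) of how the "poly" and "manual" schedules above map the number of processed samples to a learning rate:

.. code-block:: python

    import math

    def poly_lr(learning_rate, decay_a, decay_b, num_samples_processed):
        # lr = learning_rate * (1 + a * n) ** (-b)
        return learning_rate * math.pow(1 + decay_a * num_samples_processed, -decay_b)

    def manual_lr(learning_rate, learning_rate_args, num_samples_processed):
        # learning_rate_args is e.g. "1000:1.0,2000:0.9,3000:0.8"
        segments = [seg.split(":") for seg in learning_rate_args.split(",")]
        factor = float(segments[-1][1])  # factor used once every threshold is exceeded
        for threshold, f in segments:
            if num_samples_processed <= int(threshold):
                factor = float(f)
                break
        return learning_rate * factor

    print(poly_lr(1e-3, 0.5, 0.75, 10000))                      # poly schedule
    print(manual_lr(1e-3, "1000:1.0,2000:0.9,3000:0.8", 1500))  # -> 1e-3 * 0.9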
23. What to do about a :code:`Duplicated layer name` error
------------------------------------------------------------

This error is usually caused by giving the same value to the :code:`name` argument of different layers. When it occurs, first find the layers whose :code:`name` arguments share a value, then give each of those layers a distinct :code:`name`, as in the sketch below.
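A minimal sketch of the fix (the layer names and sizes are illustrative only):

.. code-block:: python

    # Before: both layers are given name="fc", which triggers "Duplicated layer name".
    # fc1 = paddle.layer.fc(input=data, size=128, name="fc")
    # fc2 = paddle.layer.fc(input=fc1, size=128, name="fc")

    # After: each layer gets a unique name.
    fc1 = paddle.layer.fc(input=data, size=128, name="fc_1")
    fc2 = paddle.layer.fc(input=fc1, size=128, name="fc_2")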
@@ -38,10 +38,10 @@ class CropKernel : public framework::OpKernel {
    auto out_stride = framework::stride(out->dims());
    auto offsets = context.Attr<std::vector<int>>("offsets");
    PADDLE_ENFORCE_EQ(
-       x->dims().size(), offsets.size(),
+       x->dims().size(), static_cast<int64_t>(offsets.size()),
        "Offsets size should be equal to dimension size of input tensor.");
    int64_t offset = 0;
-   for (int i = 0; i < offsets.size(); ++i) {
+   for (size_t i = 0; i < offsets.size(); ++i) {
      offset += (x_stride[i] * offsets[i]);
    }
    StridedMemcpy<T>(context.device_context(), x_data + offset, x_stride,
@@ -57,7 +57,7 @@ void CropGradFunction(const framework::ExecutionContext& context) {
    d_x->mutable_data<T>(context.GetPlace());
    auto offsets = context.Attr<std::vector<int>>("offsets");
    Eigen::array<std::pair<int, int>, D> paddings;
-   for (int i = 0; i < D; ++i) {
+   for (size_t i = 0; i < D; ++i) {
      paddings[i].first = offsets[i];
      paddings[i].second = d_x->dims()[i] - d_out->dims()[i] - offsets[i];
    }
......
@@ -48,6 +48,32 @@ void gemm<platform::CPUPlace, double>(const platform::DeviceContext& context,
               beta, C, ldc);
}
template <>
void gemm<platform::CPUPlace, float>(const platform::DeviceContext& context,
const bool transA, const bool transB,
const int M, const int N, const int K,
const float alpha, const float* A,
const int lda, const float* B,
const int ldb, const float beta, float* C,
const int ldc) {
cblas_sgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
}
template <>
void gemm<platform::CPUPlace, double>(const platform::DeviceContext& context,
const bool transA, const bool transB,
const int M, const int N, const int K,
const double alpha, const double* A,
const int lda, const double* B,
const int ldb, const double beta,
double* C, const int ldc) {
cblas_dgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
}
template <>
void matmul<platform::CPUPlace, float>(
    const platform::DeviceContext& context, const framework::Tensor& matrix_a,
......
@@ -63,6 +63,42 @@ void gemm<platform::GPUPlace, double>(const platform::DeviceContext& context,
      cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
}
template <>
void gemm<platform::GPUPlace, float>(const platform::DeviceContext& context,
const bool transA, const bool transB,
const int M, const int N, const int K,
const float alpha, const float* A,
const int lda, const float* B,
const int ldb, const float beta, float* C,
const int ldc) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasSgemm(
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.cublas_handle(),
cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
}
template <>
void gemm<platform::GPUPlace, double>(const platform::DeviceContext& context,
const bool transA, const bool transB,
const int M, const int N, const int K,
const double alpha, const double* A,
const int lda, const double* B,
const int ldb, const double beta,
double* C, const int ldc) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasDgemm(
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.cublas_handle(),
cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
}
template <>
void matmul<platform::GPUPlace, float>(
    const platform::DeviceContext& context, const framework::Tensor& matrix_a,
......
@@ -70,6 +70,13 @@ void gemm(const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA,
          const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
          const T alpha, const T* A, const T* B, const T beta, T* C);
// gemm wrapper with stride arguments (lda/ldb/ldc) for matrices that are not contiguous in memory
template <typename Place, typename T>
void gemm(const platform::DeviceContext& context, const bool transA,
const bool transB, const int M, const int N, const int K,
const T alpha, const T* A, const int lda, const T* B, const int ldb,
const T beta, T* C, const int ldc);
// matrix multiply with continuous memory
template <typename Place, typename T>
void matmul(const platform::DeviceContext& context,
......
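For reference, a NumPy sketch (not part of the commit) of what the new lda/ldb/ldc leading-dimension arguments express: the same sub-block multiplication that the gemm_notrans tests below check.

.. code-block:: python

    import numpy as np

    a_buf = np.arange(6, dtype=np.float32).reshape(2, 3)   # A buffer, lda = 3
    b_buf = np.arange(12, dtype=np.float32).reshape(3, 4)  # B buffer, ldb = 4
    c_buf = np.arange(8, dtype=np.float32).reshape(2, 4)   # C buffer, ldc = 4
    # gemm(false, false, m=2, n=3, k=3, alpha=1, A, lda=3, B+1, ldb=4, beta=1, C+1, ldc=4)
    c_buf[:, 1:] += a_buf.dot(b_buf[:, 1:])
    print(c_buf)  # [[0, 24, 28, 32], [4, 73, 86, 99]], as asserted in the tests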
@@ -72,4 +72,174 @@ TEST(math_function, trans_mul_notrans) {
  EXPECT_EQ(out_ptr[8], 29);
  delete gpu_place;
}
TEST(math_function, gemm_notrans_cublas) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input2;
paddle::framework::Tensor input3;
paddle::framework::Tensor input1_gpu;
paddle::framework::Tensor input2_gpu;
paddle::framework::Tensor input3_gpu;
int m = 2;
int n = 3;
int k = 3;
auto* cpu_place = new paddle::platform::CPUPlace();
float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place);
float arr1[6] = {0, 1, 2, 3, 4, 5};
memcpy(input1_ptr, arr1, 6 * sizeof(float));
float* input2_ptr = input2.mutable_data<float>({3, 4}, *cpu_place);
float arr2[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
memcpy(input2_ptr, arr2, 12 * sizeof(float));
float* input3_ptr = input3.mutable_data<float>({2, 4}, *cpu_place);
float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7};
memcpy(input3_ptr, arr3, 8 * sizeof(float));
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place);
input2_gpu.CopyFrom<float>(input2, *gpu_place);
input3_gpu.CopyFrom<float>(input3, *gpu_place);
float* a = input1_gpu.data<float>();
float* b = input2_gpu.data<float>();
float* c = input3_gpu.mutable_data<float>(*gpu_place);
paddle::operators::math::gemm<paddle::platform::GPUPlace, float>(
context, false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4);
input3.CopyFrom<float>(input3_gpu, *cpu_place);
// numpy code:
// a = np.arange(6).reshape(2, 3)
// b = np.arange(12).reshape(3, 4)[:, 1:]
// c = np.arange(8).reshape(2, 4)[:, 1:]
// out = np.arange(8).reshape(2, 4)
// out[:, 1:] = np.dot(a, b) + c
EXPECT_EQ(input3_ptr[0], 0);
EXPECT_EQ(input3_ptr[1], 24);
EXPECT_EQ(input3_ptr[2], 28);
EXPECT_EQ(input3_ptr[3], 32);
EXPECT_EQ(input3_ptr[4], 4);
EXPECT_EQ(input3_ptr[5], 73);
EXPECT_EQ(input3_ptr[6], 86);
EXPECT_EQ(input3_ptr[7], 99);
delete gpu_place;
}
TEST(math_function, gemm_trans_cublas) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input2;
paddle::framework::Tensor input3;
paddle::framework::Tensor input1_gpu;
paddle::framework::Tensor input2_gpu;
paddle::framework::Tensor input3_gpu;
int m = 2;
int n = 3;
int k = 3;
auto* cpu_place = new paddle::platform::CPUPlace();
float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place);
float arr1[6] = {0, 1, 2, 3, 4, 5};
memcpy(input1_ptr, arr1, 6 * sizeof(float));
float* input2_ptr = input2.mutable_data<float>({4, 3}, *cpu_place);
float arr2[12] = {0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11};
memcpy(input2_ptr, arr2, 12 * sizeof(float));
float* input3_ptr = input3.mutable_data<float>({2, 4}, *cpu_place);
float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7};
memcpy(input3_ptr, arr3, 8 * sizeof(float));
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place);
input2_gpu.CopyFrom<float>(input2, *gpu_place);
input3_gpu.CopyFrom<float>(input3, *gpu_place);
float* a = input1_gpu.data<float>();
float* b = input2_gpu.data<float>();
float* c = input3_gpu.mutable_data<float>(*gpu_place);
paddle::operators::math::gemm<paddle::platform::GPUPlace, float>(
context, false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4);
input3.CopyFrom<float>(input3_gpu, *cpu_place);
EXPECT_EQ(input3_ptr[0], 0);
EXPECT_EQ(input3_ptr[1], 24);
EXPECT_EQ(input3_ptr[2], 28);
EXPECT_EQ(input3_ptr[3], 32);
EXPECT_EQ(input3_ptr[4], 4);
EXPECT_EQ(input3_ptr[5], 73);
EXPECT_EQ(input3_ptr[6], 86);
EXPECT_EQ(input3_ptr[7], 99);
delete gpu_place;
}
#endif
TEST(math_function, gemm_notrans_cblas) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input2;
paddle::framework::Tensor input3;
int m = 2;
int n = 3;
int k = 3;
auto* cpu_place = new paddle::platform::CPUPlace();
float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place);
float arr1[6] = {0, 1, 2, 3, 4, 5};
memcpy(input1_ptr, arr1, 6 * sizeof(float));
float* input2_ptr = input2.mutable_data<float>({3, 4}, *cpu_place);
float arr2[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
memcpy(input2_ptr, arr2, 12 * sizeof(float));
float* input3_ptr = input3.mutable_data<float>({2, 4}, *cpu_place);
float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7};
memcpy(input3_ptr, arr3, 8 * sizeof(float));
paddle::platform::CPUDeviceContext context(*cpu_place);
paddle::operators::math::gemm<paddle::platform::CPUPlace, float>(
context, false, false, m, n, k, 1, input1_ptr, 3, input2_ptr + 1, 4, 1,
input3_ptr + 1, 4);
EXPECT_EQ(input3_ptr[0], 0);
EXPECT_EQ(input3_ptr[1], 24);
EXPECT_EQ(input3_ptr[2], 28);
EXPECT_EQ(input3_ptr[3], 32);
EXPECT_EQ(input3_ptr[4], 4);
EXPECT_EQ(input3_ptr[5], 73);
EXPECT_EQ(input3_ptr[6], 86);
EXPECT_EQ(input3_ptr[7], 99);
}
TEST(math_function, gemm_trans_clbas) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input2;
paddle::framework::Tensor input3;
int m = 2;
int n = 3;
int k = 3;
auto* cpu_place = new paddle::platform::CPUPlace();
float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place);
float arr1[6] = {0, 1, 2, 3, 4, 5};
memcpy(input1_ptr, arr1, 6 * sizeof(float));
float* input2_ptr = input2.mutable_data<float>({4, 3}, *cpu_place);
float arr2[12] = {0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11};
memcpy(input2_ptr, arr2, 12 * sizeof(float));
float* input3_ptr = input3.mutable_data<float>({2, 4}, *cpu_place);
float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7};
memcpy(input3_ptr, arr3, 8 * sizeof(float));
paddle::platform::CPUDeviceContext context(*cpu_place);
paddle::operators::math::gemm<paddle::platform::CPUPlace, float>(
context, false, true, m, n, k, 1, input1_ptr, 3, input2_ptr + 3, 3, 1,
input3_ptr + 1, 4);
EXPECT_EQ(input3_ptr[0], 0);
EXPECT_EQ(input3_ptr[1], 24);
EXPECT_EQ(input3_ptr[2], 28);
EXPECT_EQ(input3_ptr[3], 32);
EXPECT_EQ(input3_ptr[4], 4);
EXPECT_EQ(input3_ptr[5], 73);
EXPECT_EQ(input3_ptr[6], 86);
EXPECT_EQ(input3_ptr[7], 99);
}
@@ -25,24 +25,30 @@ class MultiplexOp : public framework::OperatorWithKernel {
 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
+   PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ids"),
+                           "Input(Ids) shouldn't be null.");
    PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
-                  "Input(X) should not be null.");
+                  "MultiInput(X) shouldn't be empty.");
    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
                            "Output(Out) shouldn't be null.");
+   auto ids_dim = ctx.Input<Tensor>("Ids")->dims();
+   PADDLE_ENFORCE(
+       ids_dim.size() == 2 && ids_dim[1] == 1,
+       "The index tensor must be a vector with size batchSize x 1.");
    auto ins = ctx.MultiInput<Tensor>("X");
    auto *out = ctx.Output<Tensor>("Out");
    auto num_ins = ins.size();
-   PADDLE_ENFORCE(num_ins > 2,
-                  "multiplex operator should have more than 2 inputs.");
-   PADDLE_ENFORCE_EQ(ins[0]->dims().size(), 1,
-                     "The first input must be a index vector.");
-   auto in_dim = ins[1]->dims();
-   for (size_t i = 2; i < num_ins; i++) {
+   PADDLE_ENFORCE(num_ins > 1,
+                  "multiplex operator should have more than "
+                  "one candidate input tensors.");
+   auto in_dim = ins[0]->dims();
+   PADDLE_ENFORCE(in_dim.size() == 2, "Candidate tensors must be matrix.");
+   for (size_t i = 1; i < num_ins; i++) {
      auto dim = ins[i]->dims();
      PADDLE_ENFORCE(in_dim == dim,
-                    "All the input tensors except the first one must have the "
-                    "same size.");
+                    "All the candidate tensors must have the same size.");
    }
    out->Resize(in_dim);
  }
@@ -53,25 +59,26 @@ class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
  MultiplexOpMaker(framework::OpProto *proto,
                   framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
-   AddInput("X", "The input tensors of multiplex operator.").AsDuplicable();
+   AddInput("Ids", "The index tensor of multiplex operator.");
+   AddInput("X", "The candidate tensors of multiplex operator.")
+       .AsDuplicable();
    AddOutput("Out", "The output tensor of multiplex operator.");
    AddComment(R"DOC(Multiplex operator
Multiplex multiple tensors according to the index provided by the first
input tensor.
-ins[0]: the index tensor.
-ins[1:N]: the candidate output tensors.
+Ids: the index tensor.
+X[0 : N - 1]: the candidate tensors for output (N >= 2).
For each index i from 0 to batchSize - 1, the output is the i-th row of the
-the (index[i] + 1)-th tensor.
+the (Ids[i])-th tensor.
For i-th row of the output tensor:
-y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{1}.width - 1)
+y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{0}.width - 1)
where y is the output tensor. `x_{k}` is the k-th input tensor
-and `k = x{0}[i] + 1`.
+and `k = Ids[i]`.
)DOC");
  }
};
@@ -90,8 +97,8 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
                            "Input(Out@GRAD) shouldn't be null.");
    auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
    auto ins = ctx.MultiInput<Tensor>("X");
-   // don't compute gradient for index (ins[0])
-   for (size_t i = 1; i < ins.size(); i++) {
+   // No need to compute gradient for Input(Ids)
+   for (size_t i = 0; i < ins.size(); i++) {
      if (d_ins[i]) {
        d_ins[i]->Resize(ins[i]->dims());
      }
......
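For reference, a short NumPy sketch (not part of the commit) of the semantics described in the operator's DOC string above: each output row i is copied from row i of the candidate tensor selected by Ids[i].

.. code-block:: python

    import numpy as np

    def multiplex(ids, xs):
        # ids: int array of shape (batch, 1); xs: list of (batch, width) candidate arrays
        out = np.zeros_like(xs[0])
        for i in range(xs[0].shape[0]):
            out[i] = xs[ids[i, 0]][i]
        return out

    xs = [np.full((3, 2), v, dtype=np.float32) for v in (0.0, 1.0, 2.0)]
    ids = np.array([[2], [0], [1]], dtype=np.int32)
    print(multiplex(ids, xs))  # rows come from xs[2], xs[0], xs[1] respectively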
@@ -25,21 +25,23 @@ class MultiplexGPUKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    auto ins = ctx.MultiInput<Tensor>("X");
+   auto* ids = ctx.Input<Tensor>("Ids");
    auto* out = ctx.Output<Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());
-   auto rows = ins[1]->dims()[0];
-   auto cols = ins[1]->dims()[1];
+   auto rows = ins[0]->dims()[0];
+   auto cols = ins[0]->dims()[1];
    // copy index to cpu
    Tensor index_t_cpu;
-   index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
-   auto* index = index_t_cpu.data<T>();
+   index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace());
+   auto* index = index_t_cpu.data<int32_t>();
    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                      ctx.device_context())
                      .stream();
    Place place = boost::get<Place>(ctx.GetPlace());
    for (auto i = 0; i < rows; i++) {
-     size_t k = (size_t)index[i] + 1;
+     int32_t k = index[i];
+     PADDLE_ENFORCE_GE(k, 0, "index must be nonnegative.");
      PADDLE_ENFORCE_LT(k, ins.size(),
                        "index exceeds the number of candidate tensors.");
      memory::Copy(place, out->data<T>() + i * cols, place,
@@ -54,8 +56,9 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
  void Compute(const framework::ExecutionContext& ctx) const {
    auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto ins = ctx.MultiInput<Tensor>("X");
+   auto* ids = ctx.Input<Tensor>("Ids");
    auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
-   for (size_t i = 1; i < d_ins.size(); i++) {
+   for (size_t i = 0; i < d_ins.size(); i++) {
      if (d_ins[i]) {
        d_ins[i]->mutable_data<T>(ctx.GetPlace());
        auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
@@ -63,19 +66,19 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
      }
    }
-   auto rows = ins[1]->dims()[0];
-   auto cols = ins[1]->dims()[1];
+   auto rows = ins[0]->dims()[0];
+   auto cols = ins[0]->dims()[1];
    // copy index to cpu
    Tensor index_t_cpu;
-   index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
-   auto* index = index_t_cpu.data<T>();
+   index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace());
+   auto* index = index_t_cpu.data<int32_t>();
    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                      ctx.device_context())
                      .stream();
    Place place = boost::get<Place>(ctx.GetPlace());
    for (auto i = 0; i < rows; i++) {
-     size_t k = (size_t)index[i] + 1;
+     size_t k = static_cast<size_t>(index[i]);
      if (d_ins[k]) {
        memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
                     d_out->data<T>() + i * cols, cols * sizeof(T), stream);
......
@@ -27,17 +27,19 @@ class MultiplexCPUKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    auto ins = ctx.MultiInput<framework::Tensor>("X");
+   auto ids = ctx.Input<framework::Tensor>("Ids");
    auto* out = ctx.Output<framework::Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());
-   auto rows = ins[1]->dims()[0];
-   auto cols = ins[1]->dims()[1];
-   auto* index = ins[0]->data<T>();
+   auto rows = ins[0]->dims()[0];
+   auto cols = ins[0]->dims()[1];
+   auto index = ids->data<int32_t>();
    Place place = boost::get<Place>(ctx.GetPlace());
    for (auto i = 0; i < rows; i++) {
-     size_t k = (size_t)index[i] + 1;
-     PADDLE_ENFORCE_LT(k, ins.size(),
+     int32_t k = index[i];
+     PADDLE_ENFORCE_GE(k, 0, "index must be nonnegative.");
+     PADDLE_ENFORCE_LT(static_cast<size_t>(k), ins.size(),
                        "index exceeds the number of candidate tensors.");
      memory::Copy(place, out->data<T>() + i * cols, place,
                   ins[k]->data<T>() + i * cols, cols * sizeof(T));
@@ -50,10 +52,11 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+   auto* ids = ctx.Input<framework::Tensor>("Ids");
    auto ins = ctx.MultiInput<framework::Tensor>("X");
    auto d_ins =
        ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
-   for (size_t i = 1; i < d_ins.size(); i++) {
+   for (size_t i = 0; i < d_ins.size(); i++) {
      if (d_ins[i]) {
        d_ins[i]->mutable_data<T>(ctx.GetPlace());
        auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
@@ -61,12 +64,12 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
      }
    }
-   auto rows = ins[1]->dims()[0];
-   auto cols = ins[1]->dims()[1];
-   auto* index = ins[0]->data<T>();
+   auto rows = ins[0]->dims()[0];
+   auto cols = ins[0]->dims()[1];
+   auto* index = ids->data<int32_t>();
    Place place = boost::get<Place>(ctx.GetPlace());
    for (auto i = 0; i < rows; i++) {
-     size_t k = (size_t)index[i] + 1;
+     size_t k = static_cast<size_t>(index[i]);
      if (d_ins[k]) {
        memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
                     d_out->data<T>() + i * cols, cols * sizeof(T));
@@ -74,5 +77,5 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
      }
    }
  };
-}
-}
+}  // namespace operators
+}  // namespace paddle
@@ -921,7 +921,7 @@ def data_layer(name, size, depth=None, height=None, width=None,
        data = data_layer(name="input", size=1000)
-   :param name: The name of this layer. It is optional.
+   :param name: The name of this layer.
    :type name: basestring
    :param size: Size of this data layer.
    :type size: int
@@ -3668,6 +3668,7 @@ def gru_step_naive_layer(input,
    :param param_attr:
    :param layer_attr:
    :return:
+   :rtype: LayerOutput
    """
    if input.size % 3 != 0:
        raise ValueError("GruStep input size must be divided by 3")
......
@@ -6,20 +6,22 @@ from op_test import OpTest
class TestMultiplexOp(OpTest):
    def setUp(self):
        self.op_type = "multiplex"
-       rows = 3
-       index = np.array([3, 1, 0])
+       rows = 4
+       index = np.arange(0, rows).astype('int32')
+       np.random.shuffle(index)
+       index = np.reshape(index, (rows, 1))
        ins1 = np.random.random((rows, 10)).astype("float32")
        ins2 = np.random.random((rows, 10)).astype("float32")
        ins3 = np.random.random((rows, 10)).astype("float32")
        ins4 = np.random.random((rows, 10)).astype("float32")
        self.inputs = {
-           'X': [('index', index), ('x1', ins1), ('x2', ins2), ('x3', ins3),
-                 ('x4', ins4)]
+           'Ids': index,
+           'X': [('x1', ins1), ('x2', ins2), ('x3', ins3), ('x4', ins4)]
        }
        # multiplex output
        output = np.zeros_like(ins1)
        for i in range(0, rows):
-           k = index[i] + 1
+           k = index[i][0]
            output[i] = self.inputs['X'][k][1][i]
        self.outputs = {'Out': output}
......