diff --git a/paddle/operators/conv_cudnn_op.cc b/paddle/operators/conv_cudnn_op.cc index 4c65b60d2349d2989128f4b1da705ea18391b8a3..c03dc3e4fb07ac6ecde42be93a1138d91778edf4 100644 --- a/paddle/operators/conv_cudnn_op.cc +++ b/paddle/operators/conv_cudnn_op.cc @@ -40,7 +40,8 @@ REGISTER_OP(conv_cudnn, ops::ConvOp, ops::CudnnConvOpMaker, conv_cudnn_grad, ops::ConvOpGrad); REGISTER_OP_CPU_KERNEL(conv_cudnn, - ops::GemmConvKernel); + ops::GemmConvKernel, + ops::GemmConvKernel); REGISTER_OP_CPU_KERNEL( - conv_cudnn_grad, - ops::GemmConvGradKernel); + conv_cudnn_grad, ops::GemmConvGradKernel, + ops::GemmConvGradKernel); diff --git a/paddle/operators/conv_cudnn_op.cu.cc b/paddle/operators/conv_cudnn_op.cu.cc index 4900f7b086c869b496c492743c71ab7047c5f672..5eaf6b33704eb371fff4b949c6cc32a7a5dbc812 100644 --- a/paddle/operators/conv_cudnn_op.cu.cc +++ b/paddle/operators/conv_cudnn_op.cu.cc @@ -259,6 +259,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_GPU_KERNEL(conv_cudnn, paddle::operators::CudnnConvOpKernel); +REGISTER_OP_GPU_KERNEL(conv_cudnn, paddle::operators::CudnnConvOpKernel, + paddle::operators::CudnnConvOpKernel); REGISTER_OP_GPU_KERNEL(conv_cudnn_grad, - paddle::operators::CudnnConvGradOpKernel); + paddle::operators::CudnnConvGradOpKernel, + paddle::operators::CudnnConvGradOpKernel); diff --git a/paddle/operators/conv_transpose_cudnn_op.cc b/paddle/operators/conv_transpose_cudnn_op.cc index dbd1bc3c3bc2d026f13ddcf62919db6cf7d87bc5..0192178ce3a0a47196232f0723baec8324bea60b 100644 --- a/paddle/operators/conv_transpose_cudnn_op.cc +++ b/paddle/operators/conv_transpose_cudnn_op.cc @@ -61,10 +61,12 @@ REGISTER_OP(conv2d_transpose_cudnn, ops::ConvTransposeOp, REGISTER_OP_CPU_KERNEL( conv2d_transpose_cudnn, - ops::GemmConvTransposeKernel); + ops::GemmConvTransposeKernel, + ops::GemmConvTransposeKernel); REGISTER_OP_CPU_KERNEL( conv2d_transpose_cudnn_grad, - ops::GemmConvTransposeGradKernel); + ops::GemmConvTransposeGradKernel, + ops::GemmConvTransposeGradKernel); REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp, ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad, @@ -72,7 +74,9 @@ REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp, REGISTER_OP_CPU_KERNEL( conv3d_transpose_cudnn, - ops::GemmConvTransposeKernel); + ops::GemmConvTransposeKernel, + ops::GemmConvTransposeKernel); REGISTER_OP_CPU_KERNEL( conv3d_transpose_cudnn_grad, - ops::GemmConvTransposeGradKernel); + ops::GemmConvTransposeGradKernel, + ops::GemmConvTransposeGradKernel); diff --git a/paddle/operators/conv_transpose_cudnn_op.cu.cc b/paddle/operators/conv_transpose_cudnn_op.cu.cc index e2ba77086e737a07471f14e483cbd32ab1d4ee12..494904fe524ae30a5032e489a0c5f20179d8e8ce 100644 --- a/paddle/operators/conv_transpose_cudnn_op.cu.cc +++ b/paddle/operators/conv_transpose_cudnn_op.cu.cc @@ -235,11 +235,15 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn, - ops::CudnnConvTransposeOpKernel); + ops::CudnnConvTransposeOpKernel, + ops::CudnnConvTransposeOpKernel); REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad, - ops::CudnnConvTransposeGradOpKernel); + ops::CudnnConvTransposeGradOpKernel, + ops::CudnnConvTransposeGradOpKernel); REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn, - ops::CudnnConvTransposeOpKernel); + ops::CudnnConvTransposeOpKernel, + ops::CudnnConvTransposeOpKernel); REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn_grad, - 
ops::CudnnConvTransposeGradOpKernel); + ops::CudnnConvTransposeGradOpKernel, + ops::CudnnConvTransposeGradOpKernel); diff --git a/paddle/operators/pool_cudnn_op.cc b/paddle/operators/pool_cudnn_op.cc index f962d9e3e6abde14ce21eb0102f10d139fdb160e..be9fcc5661f420aadf908cf80cce6c963008b0e4 100644 --- a/paddle/operators/pool_cudnn_op.cc +++ b/paddle/operators/pool_cudnn_op.cc @@ -20,6 +20,18 @@ REGISTER_OP(pool2d_cudnn, ops::PoolOp, ops::Pool2dOpMaker, pool2d_cudnn_grad, ops::PoolOpGrad); REGISTER_OP_CPU_KERNEL(pool2d_cudnn, - ops::PoolKernel); + ops::PoolKernel, + ops::PoolKernel); REGISTER_OP_CPU_KERNEL(pool2d_cudnn_grad, - ops::PoolGradKernel) + ops::PoolGradKernel, + ops::PoolGradKernel) + +REGISTER_OP(pool3d_cudnn, ops::PoolOp, ops::Pool3dOpMaker, pool3d_cudnn_grad, + ops::PoolOpGrad); + +REGISTER_OP_CPU_KERNEL(pool3d_cudnn, + ops::PoolKernel, + ops::PoolKernel); +REGISTER_OP_CPU_KERNEL(pool3d_cudnn_grad, + ops::PoolGradKernel, + ops::PoolGradKernel) diff --git a/paddle/operators/pool_cudnn_op.cu.cc b/paddle/operators/pool_cudnn_op.cu.cc index f9d8af3e1c5db49873979fdfeb17a32d16341a1a..66dd194ccd5ed629c5861552a7c124dc911362d7 100644 --- a/paddle/operators/pool_cudnn_op.cu.cc +++ b/paddle/operators/pool_cudnn_op.cu.cc @@ -52,7 +52,13 @@ class PoolCudnnOpKernel : public framework::OpKernel { ScopedTensorDescriptor input_desc; ScopedTensorDescriptor output_desc; ScopedPoolingDescriptor pool_desc; - DataLayout layout = DataLayout::kNCHW; + DataLayout layout; + + if (strides.size() == 2U) { + layout = DataLayout::kNCHW; + } else { + layout = DataLayout::kNCDHW; + } cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( layout, framework::vectorize2int(input->dims())); @@ -112,7 +118,13 @@ class PoolCudnnGradOpKernel : public framework::OpKernel { ScopedTensorDescriptor input_desc; ScopedTensorDescriptor output_desc; ScopedPoolingDescriptor pool_desc; - DataLayout layout = DataLayout::kNCHW; + DataLayout layout; + + if (strides.size() == 2U) { + layout = DataLayout::kNCHW; + } else { + layout = DataLayout::kNCDHW; + } cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( layout, framework::vectorize2int(input->dims())); @@ -150,5 +162,12 @@ class PoolCudnnGradOpKernel : public framework::OpKernel { namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL(pool2d_cudnn, ops::PoolCudnnOpKernel); -REGISTER_OP_GPU_KERNEL(pool2d_cudnn_grad, ops::PoolCudnnGradOpKernel); +REGISTER_OP_GPU_KERNEL(pool2d_cudnn, ops::PoolCudnnOpKernel, + ops::PoolCudnnOpKernel); +REGISTER_OP_GPU_KERNEL(pool2d_cudnn_grad, ops::PoolCudnnGradOpKernel, + ops::PoolCudnnGradOpKernel); + +REGISTER_OP_GPU_KERNEL(pool3d_cudnn, ops::PoolCudnnOpKernel, + ops::PoolCudnnOpKernel); +REGISTER_OP_GPU_KERNEL(pool3d_cudnn_grad, ops::PoolCudnnGradOpKernel, + ops::PoolCudnnGradOpKernel); diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc index f3963b1995ef8767786f0bf230b134afc69aa99d..d8c58618cf703d086d3cabc927ebc5eb038b1aec 100644 --- a/paddle/operators/pool_op.cc +++ b/paddle/operators/pool_op.cc @@ -217,14 +217,18 @@ REGISTER_OP(pool2d, ops::PoolOp, ops::Pool2dOpMaker, pool2d_grad, ops::PoolOpGrad); REGISTER_OP_CPU_KERNEL(pool2d, - ops::PoolKernel); + ops::PoolKernel, + ops::PoolKernel); REGISTER_OP_CPU_KERNEL(pool2d_grad, - ops::PoolGradKernel) + ops::PoolGradKernel, + ops::PoolGradKernel) REGISTER_OP(pool3d, ops::PoolOp, ops::Pool3dOpMaker, pool3d_grad, ops::PoolOpGrad); REGISTER_OP_CPU_KERNEL(pool3d, - ops::PoolKernel); + ops::PoolKernel, + ops::PoolKernel); 
REGISTER_OP_CPU_KERNEL(pool3d_grad, - ops::PoolGradKernel); + ops::PoolGradKernel, + ops::PoolGradKernel); diff --git a/paddle/operators/pool_op.cu.cc b/paddle/operators/pool_op.cu.cc index 0e3b80868f7b9d1697d619889160856d65ad59a3..1010cb762289dd39cd632c699f7528f4ba638278 100644 --- a/paddle/operators/pool_op.cu.cc +++ b/paddle/operators/pool_op.cu.cc @@ -17,11 +17,15 @@ limitations under the License. */ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(pool2d, - ops::PoolKernel); + ops::PoolKernel, + ops::PoolKernel); REGISTER_OP_GPU_KERNEL(pool2d_grad, - ops::PoolGradKernel); + ops::PoolGradKernel, + ops::PoolGradKernel); REGISTER_OP_GPU_KERNEL(pool3d, - ops::PoolKernel); + ops::PoolKernel, + ops::PoolKernel); REGISTER_OP_GPU_KERNEL(pool3d_grad, - ops::PoolGradKernel); + ops::PoolGradKernel, + ops::PoolGradKernel); diff --git a/paddle/operators/pool_with_index_op.cc b/paddle/operators/pool_with_index_op.cc index 1df36e965abab3549aeb88bf682b712033c4d79c..4b95c7ef6b41a9e7d2fe4654ea9acd3437757f66 100644 --- a/paddle/operators/pool_with_index_op.cc +++ b/paddle/operators/pool_with_index_op.cc @@ -250,10 +250,12 @@ REGISTER_OP(max_pool2d_with_index, ops::MaxPoolWithIndexOp, REGISTER_OP_CPU_KERNEL( max_pool2d_with_index, - ops::MaxPoolWithIndexKernel); + ops::MaxPoolWithIndexKernel, + ops::MaxPoolWithIndexKernel); REGISTER_OP_CPU_KERNEL( max_pool2d_with_index_grad, - ops::MaxPoolWithIndexGradKernel) + ops::MaxPoolWithIndexGradKernel, + ops::MaxPoolWithIndexGradKernel) REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp, ops::MaxPool3dWithIndexOpMaker, max_pool3d_with_index_grad, @@ -261,7 +263,9 @@ REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp, REGISTER_OP_CPU_KERNEL( max_pool3d_with_index, - ops::MaxPoolWithIndexKernel); + ops::MaxPoolWithIndexKernel, + ops::MaxPoolWithIndexKernel); REGISTER_OP_CPU_KERNEL( max_pool3d_with_index_grad, - ops::MaxPoolWithIndexGradKernel) + ops::MaxPoolWithIndexGradKernel, + ops::MaxPoolWithIndexGradKernel) diff --git a/paddle/operators/pool_with_index_op.cu.cc b/paddle/operators/pool_with_index_op.cu.cc index 287657d4b1c57f354ef050885f71261092bdc062..8764a71da0877a3e43fc9cab0b833199bb661962 100644 --- a/paddle/operators/pool_with_index_op.cu.cc +++ b/paddle/operators/pool_with_index_op.cu.cc @@ -18,14 +18,18 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( max_pool2d_with_index, - ops::MaxPoolWithIndexKernel); + ops::MaxPoolWithIndexKernel, + ops::MaxPoolWithIndexKernel); REGISTER_OP_GPU_KERNEL( max_pool2d_with_index_grad, - ops::MaxPoolWithIndexGradKernel) + ops::MaxPoolWithIndexGradKernel, + ops::MaxPoolWithIndexGradKernel) REGISTER_OP_GPU_KERNEL( max_pool3d_with_index, - ops::MaxPoolWithIndexKernel); + ops::MaxPoolWithIndexKernel, + ops::MaxPoolWithIndexKernel); REGISTER_OP_GPU_KERNEL( max_pool3d_with_index_grad, - ops::MaxPoolWithIndexGradKernel) + ops::MaxPoolWithIndexGradKernel, + ops::MaxPoolWithIndexGradKernel) diff --git a/paddle/operators/sequence_slice_op.cc b/paddle/operators/sequence_slice_op.cc new file mode 100755 index 0000000000000000000000000000000000000000..cbe0b4233160dd1f3ebdf6db8b5f6df392efdfe7 --- /dev/null +++ b/paddle/operators/sequence_slice_op.cc @@ -0,0 +1,132 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/sequence_slice_op.h" + +namespace paddle { +namespace operators { + +class SequenceSliceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceSliceOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Offset"), + "Input(Offset) of SequenceSliceOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Length"), + "Input(Length) of SequenceSliceOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SequenceSliceOp should not be null."); + auto input_dims = ctx->GetInputDim("X"); + + auto offset_dim = ctx->GetInputDim("Offset"); + auto length_dim = ctx->GetInputDim("Length"); + + PADDLE_ENFORCE_EQ( + offset_dim.size(), 2UL, + "Only support one level sequence now, The rank of offset must be 2."); + PADDLE_ENFORCE_EQ( + length_dim.size(), 2UL, + "Only support one level sequence now, The rank of Length must be 2."); + + // Initialize the output's dims to maximum, + // and re-set to real dims by the value of Offset and Length at kernel + ctx->SetOutputDim("Out", input_dims); + } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +class SequenceSliceGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "The gradient of Out should not be null."); + PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")), + "The gradient of X should not be null."); + ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X")); + } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SequenceSliceOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(LoDTensor), " + "the input of SequenceSliceOp."); + AddInput("Offset", + "(Tensor), " + "a vector to describe the offset of every input sequence for " + "sub sequence item."); + AddInput("Length", + "(Tensor), " + "a vector to describe the length of every input sequence for " + "sub sequence item."); + AddOutput("Out", + "(LoDTensor), the output of SequenceSliceOp."); + AddComment(R"DOC( +Sequence slice operator + +The operator crops a subsequence from given sequence with given start offset and subsequence length. +It only supports sequence (LoD Tensor with level number is 1). 
+- Case:
+      X = [[a1, a2;
+            b1, b2;
+            c1, c2]
+           [d1, d2;
+            e1, e2]]
+      LoD(X) = {{0, 3, 5}}; Dims(X) = (5, 2)
+      Offset = [[0], [1]]; Length = [[2], [1]]
+
+      Out = [[a1, a2;
+              b1, b2]
+             [e1, e2]]
+      LoD(Out) = {{0, 2, 3}}; Dims(Out) = (3, 2)
+NOTE: The first dimension sizes of Offset and Length must equal the number of
+      sequences in the input X. Offsets start from 0.
+    )DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(sequence_slice, ops::SequenceSliceOp, ops::SequenceSliceOpMaker,
+            sequence_slice_grad, ops::SequenceSliceGradOp);
+REGISTER_OP_CPU_KERNEL(
+    sequence_slice,
+    ops::SequenceSliceOpKernel);
+REGISTER_OP_CPU_KERNEL(
+    sequence_slice_grad,
+    ops::SequenceSliceGradOpKernel);
diff --git a/paddle/operators/sequence_slice_op.cu b/paddle/operators/sequence_slice_op.cu
new file mode 100755
index 0000000000000000000000000000000000000000..a9f59dadba74d900fa5cc0601fb5b264ea19e34d
--- /dev/null
+++ b/paddle/operators/sequence_slice_op.cu
@@ -0,0 +1,23 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sequence_slice_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(
+    sequence_slice,
+    ops::SequenceSliceOpKernel);
+REGISTER_OP_GPU_KERNEL(
+    sequence_slice_grad,
+    ops::SequenceSliceGradOpKernel);
diff --git a/paddle/operators/sequence_slice_op.h b/paddle/operators/sequence_slice_op.h
new file mode 100755
index 0000000000000000000000000000000000000000..2c9b8464a1236a054cf1a38b9dc1d73588f8dd38
--- /dev/null
+++ b/paddle/operators/sequence_slice_op.h
@@ -0,0 +1,173 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
+#include "paddle/operators/strided_memcpy.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+using LoD = framework::LoD;
+
+template
+inline LoD SequenceSliceLoD(const T& in, const int64_t* offset_data,
+                            const int64_t* length_data) {
+  auto out_lod = in.lod();
+  size_t lod_offset = 0;
+
+  auto n = in.lod()[0].size() - 1;
+  out_lod[0][0] = 0;
+  for (size_t i = 0; i < n; ++i) {
+    lod_offset += length_data[i];
+    out_lod[0][i + 1] = lod_offset;
+  }
+  return out_lod;
+}
+
+template
+class SequenceSliceOpKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* in = ctx.Input("X");
+    auto* offset = ctx.Input("Offset");
+    auto* length = ctx.Input("Length");
+    auto* out = ctx.Output("Out");
+
+    auto lod = in->lod();
+    auto n = lod[0].size() - 1;
+
+    PADDLE_ENFORCE_EQ(lod.size(), 1UL,
+                      "Only support one level sequence now.");
+    PADDLE_ENFORCE_EQ(
+        n, static_cast(length->dims()[0]),
+        "The size of input-sequence and length-array should be the same")
+    PADDLE_ENFORCE_EQ(
+        n, static_cast(offset->dims()[0]),
+        "The size of input-sequence and offset-array should be the same")
+
+    const int64_t* offset_data = offset->data();
+    const int64_t* length_data = length->data();
+    framework::Tensor offset_cpu;
+    framework::Tensor length_cpu;
+
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+      offset_cpu.mutable_data(offset->dims(), platform::CPUPlace());
+      offset_cpu.CopyFrom(*offset, platform::CPUPlace(), ctx.device_context());
+      offset_data = offset_cpu.data();
+
+      length_cpu.mutable_data(length->dims(), platform::CPUPlace());
+      length_cpu.CopyFrom(*length, platform::CPUPlace(), ctx.device_context());
+      length_data = length_cpu.data();
+    }
+
+    for (size_t i = 0; i < n; ++i) {
+      PADDLE_ENFORCE_LE(0, offset_data[i],
+                        "The offset[%d] must be greater than or equal to zero.",
+                        i)
+      PADDLE_ENFORCE_LT(0, length_data[i],
+                        "The length[%d] must be greater than zero.", i)
+      PADDLE_ENFORCE_LE(
+          lod[0][i] + offset_data[i] + length_data[i],
+          lod[0][i + 1],
+          "The sum of offset and length must be within the input sequence.")
+    }
+
+    out->mutable_data(ctx.GetPlace());
+    auto out_lod = SequenceSliceLoD(*in, offset_data, length_data);
+    auto out_dims = in->dims();
+    out_dims[0] = out_lod[0][out_lod[0].size() - 1];
+    out->Resize(out_dims);
+    out->set_lod(out_lod);
+
+    auto in_stride = framework::stride(in->dims());
+    auto out_stride = framework::stride(out->dims());
+
+    size_t out_offset = 0;
+    for (size_t i = 0; i < n; ++i) {
+      Tensor in_t =
+          in->Slice(static_cast(lod[0][i] + offset_data[i]),
+                    static_cast(lod[0][i] + offset_data[i] +
+                                length_data[i]));
+
+      StridedMemcpy(ctx.device_context(), in_t.data(),
+                    in_stride, in_t.dims(), out_stride,
+                    out->data() + out_offset);
+      out_offset += length_data[i] * in_stride[0];
+    }
+  }
+};
+
+template
+class SequenceSliceGradOpKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* in = ctx.Input("X");
+    auto* offset = ctx.Input("Offset");
+    auto* length = ctx.Input("Length");
+    auto* out_grad =
+        ctx.Input(framework::GradVarName("Out"));
+    auto* x_grad =
+        ctx.Output(framework::GradVarName("X"));
+
+    const int64_t* offset_data = offset->data();
+    const int64_t* length_data = length->data();
+    framework::Tensor offset_cpu;
+    framework::Tensor length_cpu;
+
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+      offset_cpu.mutable_data(offset->dims(), platform::CPUPlace());
+      offset_cpu.CopyFrom(*offset, platform::CPUPlace(), ctx.device_context());
+      offset_data = offset_cpu.data();
+
+      length_cpu.mutable_data(length->dims(), platform::CPUPlace());
+      length_cpu.CopyFrom(*length, platform::CPUPlace(), ctx.device_context());
+      length_data = length_cpu.data();
+    }
+
+    auto lod = in->lod();
+    auto out_lod = out_grad->lod();
+
+    if (x_grad) {
+      x_grad->mutable_data(ctx.GetPlace());
+      x_grad->set_lod(in->lod());
+      math::SetConstant set_zero;
+      set_zero(ctx.device_context(), x_grad, static_cast(0));
+
+      auto out_grad_stride = framework::stride(out_grad->dims());
+
+      for (size_t i = 0; i < out_lod[0].size() - 1; ++i) {
+        Tensor out_grad_t =
+            out_grad->Slice(static_cast(out_lod[0][i]),
+                            static_cast(out_lod[0][i + 1]));
+        auto out_grad_stride = framework::stride(out_grad_t.dims());
+
+        auto x_grad_stride = framework::stride(x_grad->dims());
+
+        Tensor x_grad_t = x_grad->Slice(
+            static_cast(lod[0][i] + offset_data[i]),
+            static_cast(lod[0][i] + offset_data[i] + length_data[i]));
+
+        StridedMemcpy(ctx.device_context(), out_grad_t.data(),
+                      out_grad_stride, out_grad_t.dims(), x_grad_stride,
+                      x_grad_t.data());
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/platform/cudnn_helper.h b/paddle/platform/cudnn_helper.h
index dd48605b9ed688e4656d4cd1ddf1f298d0a50a9e..c5d8a6066ef3becb601344590f977a38c2af0a63 100644
--- a/paddle/platform/cudnn_helper.h
+++ b/paddle/platform/cudnn_helper.h
@@ -224,13 +224,15 @@ class ScopedConvolutionDescriptor {
     PADDLE_ENFORCE_EQ(pads.size(), strides.size());
     PADDLE_ENFORCE_EQ(pads.size(), dilations.size());
-#if CUDNN_VERSION < 6000
+#if !CUDNN_VERSION_MIN(6, 0, 0)
     // cudnn v5 does not support dilation conv, the argument is called upscale
     // instead of dilations and it must be one.
for (size_t i = 0; i < dilations.size(); ++i) { PADDLE_ENFORCE_EQ( dilations[i], 1, - "Dilations conv is not supported in this cuDNN version"); + "Dilations conv is not supported in this cuDNN version(%d.%d.%d).", + CUDNN_VERSION / 1000, CUDNN_VERSION % 1000 / 100, + CUDNN_VERSION % 100); } #endif diff --git a/paddle/platform/cudnn_helper_test.cc b/paddle/platform/cudnn_helper_test.cc index 6bd85ae1ca8b47b203e0321e9d9224d5cfd3a586..427359f69713b961c4730b697d3ccde5f7085838 100644 --- a/paddle/platform/cudnn_helper_test.cc +++ b/paddle/platform/cudnn_helper_test.cc @@ -38,6 +38,26 @@ TEST(CudnnHelper, ScopedTensorDescriptor) { EXPECT_EQ(strides[2], 6); EXPECT_EQ(strides[1], 36); EXPECT_EQ(strides[0], 144); + + // test tensor5d: ScopedTensorDescriptor + ScopedTensorDescriptor tensor5d_desc; + std::vector shape_5d = {2, 4, 6, 6, 6}; + auto desc_5d = tensor5d_desc.descriptor(DataLayout::kNCDHW, shape_5d); + + std::vector dims_5d(5); + std::vector strides_5d(5); + paddle::platform::dynload::cudnnGetTensorNdDescriptor( + desc_5d, 5, &type, &nd, dims_5d.data(), strides_5d.data()); + + EXPECT_EQ(nd, 5); + for (size_t i = 0; i < dims_5d.size(); ++i) { + EXPECT_EQ(dims_5d[i], shape_5d[i]); + } + EXPECT_EQ(strides_5d[4], 1); + EXPECT_EQ(strides_5d[3], 6); + EXPECT_EQ(strides_5d[2], 36); + EXPECT_EQ(strides_5d[1], 216); + EXPECT_EQ(strides_5d[0], 864); } TEST(CudnnHelper, ScopedFilterDescriptor) { @@ -60,6 +80,20 @@ TEST(CudnnHelper, ScopedFilterDescriptor) { for (size_t i = 0; i < shape.size(); ++i) { EXPECT_EQ(kernel[i], shape[i]); } + + ScopedFilterDescriptor filter_desc_4d; + std::vector shape_4d = {2, 3, 3, 3}; + auto desc_4d = filter_desc.descriptor(DataLayout::kNCDHW, shape_4d); + + std::vector kernel_4d(4); + paddle::platform::dynload::cudnnGetFilterNdDescriptor( + desc_4d, 4, &type, &format, &nd, kernel_4d.data()); + + EXPECT_EQ(GetCudnnTensorFormat(DataLayout::kNCHW), format); + EXPECT_EQ(nd, 4); + for (size_t i = 0; i < shape_4d.size(); ++i) { + EXPECT_EQ(kernel_4d[i], shape_4d[i]); + } } TEST(CudnnHelper, ScopedConvolutionDescriptor) { diff --git a/python/paddle/trainer_config_helpers/activations.py b/python/paddle/trainer_config_helpers/activations.py index b4c6e7fc305adacdaa64fad857474ea1ee2cdc97..00efc01c0592107314f5b23c951706d039d49a88 100644 --- a/python/paddle/trainer_config_helpers/activations.py +++ b/python/paddle/trainer_config_helpers/activations.py @@ -256,7 +256,7 @@ class SoftSignActivation(BaseActivation): SoftSign Activation. .. math:: - f(z)=\\frac{1}{1 + |z|} + f(z)=\\frac{z}{1 + |z|} """ def __init__(self): diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py index d323d34c3ff47614342934c2a02492f66d27dc10..9776ae18057d57dd994fac8b62090258252922c6 100644 --- a/python/paddle/trainer_config_helpers/networks.py +++ b/python/paddle/trainer_config_helpers/networks.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-
+import math
 from activations import LinearActivation, ReluActivation, SoftmaxActivation, \
     IdentityActivation, TanhActivation, SequenceSoftmaxActivation
@@ -26,9 +26,9 @@ __all__ = [
     'sequence_conv_pool', 'simple_lstm', "simple_img_conv_pool",
     "img_conv_bn_pool", 'lstmemory_group', 'lstmemory_unit', 'small_vgg',
     'img_conv_group', 'vgg_16_network', 'gru_unit', 'gru_group', 'simple_gru',
-    'simple_attention', 'dot_product_attention', 'simple_gru2',
-    'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm', 'inputs',
-    'outputs'
+    'simple_attention', 'dot_product_attention', 'multi_head_attention',
+    'simple_gru2', 'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm',
+    'inputs', 'outputs'
 ]
 
 ######################################################
@@ -1476,10 +1476,8 @@ def dot_product_attention(encoded_sequence,
         expand_as=encoded_sequence,
         name='%s_expand' % name)
 
-    m = linear_comb_layer(
-        weights=expanded,
-        vectors=encoded_sequence,
-        name='%s_dot-product' % name)
+    m = dot_prod_layer(
+        input1=expanded, input2=encoded_sequence, name='%s_dot-product' % name)
 
     attention_weight = fc_layer(
         input=m,
@@ -1498,6 +1496,134 @@ def dot_product_attention(encoded_sequence,
         input=scaled, pooling_type=SumPooling(), name="%s_pooling" % name)
 
 
+@wrap_name_default()
+def multi_head_attention(query,
+                         key,
+                         value,
+                         key_proj_size,
+                         value_proj_size,
+                         head_num,
+                         attention_type,
+                         softmax_param_attr=None,
+                         name=None):
+    """
+    Calculate and return a context vector with a multi-head attention mechanism.
+    The dimension of the context vector equals value_proj_size * head_num.
+
+    Please refer to **Attention Is All You Need** for more details. The link is
+    as follows:
+    https://arxiv.org/abs/1706.03762.
+
+    The example usage is:
+
+    .. code-block:: python
+
+        context = multi_head_attention(query=decoder_state,
+                                       key=enc_seq,
+                                       value=enc_seq,
+                                       key_proj_size=64,
+                                       value_proj_size=64,
+                                       head_num=8,
+                                       attention_type='dot-product attention')
+
+    :param name: A prefix attached to the name of each layer defined inside
+                 multi_head_attention.
+    :type name: basestring
+    :param softmax_param_attr: The parameter attribute of the sequence softmax
+                               that is used to produce the attention weights.
+    :type softmax_param_attr: ParameterAttribute
+    :param query: query is used to calculate attention weights over values at
+                  the current step.
+    :type query: LayerOutput
+    :param key: key is used to calculate the attention weight of the
+                corresponding value.
+    :type key: LayerOutput
+    :param value: value is the sequence to be attended.
+    :type value: LayerOutput
+    :param key_proj_size: The dimension of the linear projection performed on
+                          key and query.
+    :type key_proj_size: int
+    :param value_proj_size: The dimension of the linear projection performed on
+                            value.
+    :type value_proj_size: int
+    :param head_num: The number of attention heads.
+    :type head_num: int
+    :param attention_type: The type of the attention mechanism used in each
+                           attention head. Currently, only scaled dot-product
+                           attention and additive attention are supported.
+    :type attention_type: basestring
+    :return: The context vector.
+ :rtype: LayerOutput + """ + assert attention_type in ['dot-product attention', 'additive attention'] + + with mixed_layer( + size=key_proj_size * head_num, + name='%s_query_proj' % name) as query_proj: + query_proj += full_matrix_projection(query) + query_proj = expand_layer(input=query_proj, expand_as=key) + + with mixed_layer( + size=key_proj_size * head_num, + name='%s_key_proj' % name) as key_proj: + key_proj += full_matrix_projection(key) + + with mixed_layer( + size=value_proj_size * head_num, + name='%s_value_proj' % name) as value_proj: + value_proj += full_matrix_projection(value) + + head_list = [] + for i in range(head_num): + with mixed_layer(size=key_proj_size) as sub_query_proj: + sub_query_proj += identity_projection( + query_proj, offset=key_proj_size * i, size=key_proj_size) + + with mixed_layer(size=key_proj_size) as sub_key_proj: + sub_key_proj += identity_projection( + key_proj, offset=key_proj_size * i, size=key_proj_size) + + with mixed_layer(size=value_proj_size) as sub_value_proj: + sub_value_proj += identity_projection( + value_proj, offset=value_proj_size * i, size=value_proj_size) + + if attention_type == 'dot-product attention': + m = dot_prod_layer( + input1=sub_query_proj, + input2=sub_key_proj, + name='%s_dot-product_%d' % (name, i)) + m = slope_intercept_layer( + input=m, + slope=math.sqrt(1.0 / key_proj_size), + name='%s_dot-product_scaling_%d' % (name, i)) + else: + with mixed_layer( + size=key_proj_size, + act=TanhActivation(), + name='%s_combine_%d' % (name, i)) as m: + m += identity_projection(sub_query_proj) + m += identity_projection(sub_key_proj) + + attention_weight = fc_layer( + input=m, + size=1, + act=SequenceSoftmaxActivation(), + param_attr=softmax_param_attr, + name="%s_softmax_%d" % (name, i), + bias_attr=False) + + scaled = scaling_layer( + weight=attention_weight, + input=sub_value_proj, + name='%s_scaling_%d' % (name, i)) + head = pooling_layer( + input=scaled, + pooling_type=SumPooling(), + name="%s_pooling_%d" % (name, i)) + + head_list.append(head) + + attended = concat_layer(head_list) + + return attended + + def inputs(layers, *args): """ Declare the inputs of network. 
The order of input should be as same as diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh index 6c09ca3d346e36b6f38ac5c89914e95a7383f719..a21f67a2d99e7eab39708e2a571d30d7e9f20ce6 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh @@ -11,7 +11,6 @@ test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_l test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer test_seq_slice_layer test_cross_entropy_over_beam test_roi_pool_layer test_pooling3D_layer test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer -test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer test_scale_sub_region_layer test_dot_prod_layer test_l2_distance_layer) export whole_configs=(test_split_datasource) diff --git a/python/paddle/v2/fluid/tests/test_pool2d_op.py b/python/paddle/v2/fluid/tests/test_pool2d_op.py index ac3fa6aa87835b3cd6fb9bbf6fe66b1d0c577ca2..5dff6270f455395ce6ca8ae2428236f630467095 100644 --- a/python/paddle/v2/fluid/tests/test_pool2d_op.py +++ b/python/paddle/v2/fluid/tests/test_pool2d_op.py @@ -3,8 +3,7 @@ import numpy as np from op_test import OpTest -def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): - +def max_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=0): N, C, H, W = x.shape if global_pool == 1: ksize = [H, W] @@ -23,8 +22,7 @@ def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): return out -def avg_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): - +def avg_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=0): N, C, H, W = x.shape if global_pool == 1: ksize = [H, W] @@ -47,6 +45,7 @@ def avg_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): class TestPool2d_Op(OpTest): def setUp(self): self.init_test_case() + self.init_global_pool() self.init_op_type() self.init_pool_type() if self.global_pool: @@ -75,8 +74,6 @@ class TestPool2d_Op(OpTest): self.check_grad(set(['X']), 'Out', max_relative_error=0.07) def init_test_case(self): - self.global_pool = True - self.pool2D_forward_naive = avg_pool2D_forward_naive self.shape = [2, 3, 5, 5] self.ksize = [3, 3] self.strides = [1, 1] @@ -87,12 +84,14 @@ class TestPool2d_Op(OpTest): def init_pool_type(self): self.pool_type = "avg" + self.pool2D_forward_naive = avg_pool2D_forward_naive + + def init_global_pool(self): + self.global_pool = True class TestCase1(TestPool2d_Op): def init_test_case(self): - self.global_pool = False - self.pool2D_forward_naive = avg_pool2D_forward_naive self.shape = [2, 3, 7, 7] self.ksize = [3, 3] self.strides = [1, 1] @@ -103,12 +102,14 @@ class TestCase1(TestPool2d_Op): def init_pool_type(self): self.pool_type = "avg" + self.pool2D_forward_naive = avg_pool2D_forward_naive + + def init_global_pool(self): + self.global_pool = False class TestCase2(TestPool2d_Op): def init_test_case(self): - self.global_pool = False - self.pool2D_forward_naive = avg_pool2D_forward_naive self.shape = [2, 3, 7, 7] self.ksize = [3, 3] self.strides = [1, 1] @@ -119,152 +120,69 @@ class TestCase2(TestPool2d_Op): def init_pool_type(self): self.pool_type = "avg" + self.pool2D_forward_naive = avg_pool2D_forward_naive + def init_global_pool(self): + self.global_pool = False -class TestCase3(TestPool2d_Op): - def init_test_case(self): - self.global_pool = True 
- self.pool2D_forward_naive = max_pool2D_forward_naive - self.shape = [2, 3, 5, 5] - self.ksize = [3, 3] - self.strides = [1, 1] - self.paddings = [0, 0] +class TestCase3(TestPool2d_Op): def init_op_type(self): self.op_type = "pool2d" def init_pool_type(self): self.pool_type = "max" - - -class TestCase4(TestPool2d_Op): - def init_test_case(self): - self.global_pool = False self.pool2D_forward_naive = max_pool2D_forward_naive - self.shape = [2, 3, 7, 7] - self.ksize = [3, 3] - self.strides = [1, 1] - self.paddings = [0, 0] + +class TestCase4(TestCase1): def init_op_type(self): self.op_type = "pool2d" def init_pool_type(self): self.pool_type = "max" - - -class TestCase5(TestPool2d_Op): - def init_test_case(self): - self.global_pool = False self.pool2D_forward_naive = max_pool2D_forward_naive - self.shape = [2, 3, 7, 7] - self.ksize = [3, 3] - self.strides = [1, 1] - self.paddings = [1, 1] + +class TestCase5(TestCase2): def init_op_type(self): self.op_type = "pool2d" def init_pool_type(self): self.pool_type = "max" + self.pool2D_forward_naive = max_pool2D_forward_naive #--------------------test pool2d_cudnn-------------------- -class TestCaseCudnn1(TestPool2d_Op): - def init_test_case(self): - self.global_pool = True - self.pool2D_forward_naive = avg_pool2D_forward_naive - self.shape = [2, 3, 5, 5] - self.ksize = [3, 3] - self.strides = [1, 1] - self.paddings = [0, 0] - +class TestCudnnCase1(TestPool2d_Op): def init_op_type(self): self.op_type = "pool2d_cudnn" - def init_pool_type(self): - self.pool_type = "avg" - - -class TestCaseCudnn2(TestPool2d_Op): - def init_test_case(self): - self.global_pool = False - self.pool2D_forward_naive = avg_pool2D_forward_naive - self.shape = [2, 3, 7, 7] - self.ksize = [3, 3] - self.strides = [1, 1] - self.paddings = [0, 0] +class TestCudnnCase2(TestCase1): def init_op_type(self): self.op_type = "pool2d_cudnn" - def init_pool_type(self): - self.pool_type = "avg" - - -class TestCaseCudnn3(TestPool2d_Op): - def init_test_case(self): - self.global_pool = False - self.pool2D_forward_naive = avg_pool2D_forward_naive - self.shape = [2, 3, 7, 7] - self.ksize = [3, 3] - self.strides = [1, 1] - self.paddings = [1, 1] +class TestCudnnCase3(TestCase2): def init_op_type(self): self.op_type = "pool2d_cudnn" - def init_pool_type(self): - self.pool_type = "avg" - - -class TestCaseCudnn4(TestPool2d_Op): - def init_test_case(self): - self.global_pool = True - self.pool2D_forward_naive = max_pool2D_forward_naive - self.shape = [2, 3, 5, 5] - self.ksize = [3, 3] - self.strides = [1, 1] - self.paddings = [0, 0] +class TestCudnnCase4(TestCase3): def init_op_type(self): self.op_type = "pool2d_cudnn" - def init_pool_type(self): - self.pool_type = "max" - - -class TestCaseCudnn5(TestPool2d_Op): - def init_test_case(self): - self.global_pool = False - self.pool2D_forward_naive = max_pool2D_forward_naive - self.shape = [2, 3, 7, 7] - self.ksize = [3, 3] - self.strides = [1, 1] - self.paddings = [0, 0] +class TestCudnnCase5(TestCase4): def init_op_type(self): self.op_type = "pool2d_cudnn" - def init_pool_type(self): - self.pool_type = "max" - - -class TestCaseCudnn6(TestPool2d_Op): - def init_test_case(self): - self.global_pool = False - self.pool2D_forward_naive = max_pool2D_forward_naive - self.shape = [2, 3, 7, 7] - self.ksize = [3, 3] - self.strides = [1, 1] - self.paddings = [1, 1] +class TestCudnnCase6(TestCase5): def init_op_type(self): self.op_type = "pool2d_cudnn" - def init_pool_type(self): - self.pool_type = "max" - if __name__ == '__main__': unittest.main() diff --git 
a/python/paddle/v2/fluid/tests/test_pool3d_op.py b/python/paddle/v2/fluid/tests/test_pool3d_op.py index 87483ae5e568c01141ff789f37e84069cb8e827d..2ba86665a7d207e61159c02643fa40daca3be080 100644 --- a/python/paddle/v2/fluid/tests/test_pool3d_op.py +++ b/python/paddle/v2/fluid/tests/test_pool3d_op.py @@ -3,8 +3,7 @@ import numpy as np from op_test import OpTest -def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): - +def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=0): N, C, D, H, W = x.shape if global_pool == 1: ksize = [D, H, W] @@ -27,8 +26,7 @@ def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): return out -def avg_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): - +def avg_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=0): N, C, D, H, W = x.shape if global_pool == 1: ksize = [D, H, W] @@ -55,6 +53,10 @@ def avg_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): class TestPool3d_Op(OpTest): def setUp(self): self.init_test_case() + self.init_global_pool() + self.init_op_type() + self.init_pool_type() + if self.global_pool: self.paddings = [0 for _ in range(len(self.paddings))] input = np.random.random(self.shape).astype("float32") @@ -81,74 +83,115 @@ class TestPool3d_Op(OpTest): self.check_grad(set(['X']), 'Out', max_relative_error=0.07) def init_test_case(self): - self.global_pool = True - self.op_type = "pool3d" - self.pool_type = "avg" - self.pool3D_forward_naive = avg_pool3D_forward_naive self.shape = [2, 3, 5, 5, 5] self.ksize = [3, 3, 3] self.strides = [1, 1, 1] self.paddings = [0, 0, 0] + def init_op_type(self): + self.op_type = "pool3d" + + def init_pool_type(self): + self.pool_type = "avg" + self.pool3D_forward_naive = avg_pool3D_forward_naive + + def init_global_pool(self): + self.global_pool = True + class TestCase1(TestPool3d_Op): def init_test_case(self): - self.global_pool = False self.op_type = "pool3d" - self.pool_type = "avg" - self.pool3D_forward_naive = avg_pool3D_forward_naive self.shape = [2, 3, 7, 7, 7] self.ksize = [3, 3, 3] self.strides = [1, 1, 1] self.paddings = [0, 0, 0] - -class TestCase2(TestPool3d_Op): - def init_test_case(self): - self.global_pool = False + def init_op_type(self): self.op_type = "pool3d" + + def init_pool_type(self): self.pool_type = "avg" self.pool3D_forward_naive = avg_pool3D_forward_naive + + def init_global_pool(self): + self.global_pool = False + + +class TestCase2(TestPool3d_Op): + def init_test_case(self): self.shape = [2, 3, 7, 7, 7] self.ksize = [3, 3, 3] self.strides = [1, 1, 1] self.paddings = [1, 1, 1] + def init_op_type(self): + self.op_type = "pool3d" + + def init_pool_type(self): + self.pool_type = "avg" + self.pool3D_forward_naive = avg_pool3D_forward_naive + + def init_global_pool(self): + self.global_pool = False + class TestCase3(TestPool3d_Op): - def init_test_case(self): - self.global_pool = True + def init_op_type(self): self.op_type = "pool3d" + + def init_pool_type(self): self.pool_type = "max" self.pool3D_forward_naive = max_pool3D_forward_naive - self.shape = [2, 3, 5, 5, 5] - self.ksize = [3, 3, 3] - self.strides = [1, 1, 1] - self.paddings = [0, 0, 0] -class TestCase4(TestPool3d_Op): - def init_test_case(self): - self.global_pool = False +class TestCase4(TestCase1): + def init_op_type(self): self.op_type = "pool3d" + + def init_pool_type(self): self.pool_type = "max" self.pool3D_forward_naive = max_pool3D_forward_naive - self.shape = [2, 3, 7, 7, 7] - self.ksize = [3, 3, 3] - 
self.strides = [1, 1, 1] - self.paddings = [0, 0, 0] -class TestCase5(TestPool3d_Op): - def init_test_case(self): - self.global_pool = False +class TestCase5(TestCase2): + def init_op_type(self): self.op_type = "pool3d" + + def init_pool_type(self): self.pool_type = "max" self.pool3D_forward_naive = max_pool3D_forward_naive - self.shape = [2, 3, 7, 7, 7] - self.ksize = [3, 3, 3] - self.strides = [1, 1, 1] - self.paddings = [1, 1, 1] + + +#--------------------test pool3d_cudnn-------------------- +class TestCudnnCase1(TestPool3d_Op): + def init_op_type(self): + self.op_type = "pool3d_cudnn" + + +class TestCudnnCase2(TestCase1): + def init_op_type(self): + self.op_type = "pool3d_cudnn" + + +class TestCudnnCase3(TestCase2): + def init_op_type(self): + self.op_type = "pool3d_cudnn" + + +class TestCudnnCase4(TestCase3): + def init_op_type(self): + self.op_type = "pool3d_cudnn" + + +class TestCudnnCase5(TestCase4): + def init_op_type(self): + self.op_type = "pool3d_cudnn" + + +class TestCudnnCase6(TestCase5): + def init_op_type(self): + self.op_type = "pool3d_cudnn" if __name__ == '__main__': diff --git a/python/paddle/v2/fluid/tests/test_sequence_slice_op.py b/python/paddle/v2/fluid/tests/test_sequence_slice_op.py new file mode 100644 index 0000000000000000000000000000000000000000..ccd9a05343b0c4aa05b258959665c0662f271512 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_sequence_slice_op.py @@ -0,0 +1,47 @@ +import unittest +import numpy as np +import sys +from op_test import OpTest + + +class TestSequenceSliceOp(OpTest): + def set_data(self): + self.init_test_case() + # only supprot one level LoD + x = np.random.random(self.x_dim).astype('float32') + lod = self.x_lod + offset = np.array(self.offset).astype("int64") + length = np.array(self.length).astype("int64") + + self.inputs = {'X': (x, lod), 'Offset': offset, 'Length': length} + outs = [] #np.zeros((100, 3, 2)).astype('float32') + out_lod = [[0]] + out_lod_offset = 0 + for i in range(len(offset)): + sub_x = x[lod[0][i] + offset[i, 0]:lod[0][i] + offset[i, 0] + + length[i, 0], :] + out_lod_offset = out_lod_offset + len(sub_x) + outs.append(sub_x) + out_lod[0].append(out_lod_offset) + outs = np.concatenate(outs, axis=0) + self.outputs = {'Out': (outs, out_lod)} + + def init_test_case(self): + self.x_dim = (100, 3, 2) + self.x_lod = [[0, 20, 40, 60, 80, 100]] + self.offset = [[1], [2], [3], [4], [5]] + self.length = [[10], [8], [6], [4], [2]] + + def setUp(self): + self.op_type = "sequence_slice" + self.set_data() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +if __name__ == '__main__': + unittest.main()
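
To make the semantics of the new sequence_slice operator easier to check against the DOC block in sequence_slice_op.cc, here is a minimal NumPy sketch of the forward pass for a level-1 LoD input. It is not part of the patch; the function name and argument layout are illustrative only, and it mirrors the same per-sequence slicing performed by SequenceSliceOpKernel and the unit test.

    import numpy as np

    def sequence_slice_forward(x, lod, offset, length):
        # x: (total_rows, ...) stacked sequences; lod: level-1 boundaries, e.g. [0, 3, 5]
        out_rows, out_lod = [], [0]
        for i in range(len(lod) - 1):
            start = lod[i] + offset[i]
            out_rows.append(x[start:start + length[i]])
            out_lod.append(out_lod[-1] + length[i])
        return np.concatenate(out_rows, axis=0), out_lod

    # The example from the op documentation: LoD(X) = {{0, 3, 5}}, Offset = [0, 1], Length = [2, 1].
    x = np.arange(10, dtype=np.float32).reshape(5, 2)
    out, out_lod = sequence_slice_forward(x, [0, 3, 5], offset=[0, 1], length=[2, 1])
    print(out.shape, out_lod)  # (3, 2) [0, 2, 3]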
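For the 'dot-product attention' branch of the new multi_head_attention helper, each head takes a dot product between the projected query (expanded over the key sequence) and the projected keys, scales it by sqrt(1.0 / key_proj_size), turns the scores into weights with a sequence softmax, and sum-pools the weighted value projections; the heads are then concatenated. A rough NumPy sketch of a single head follows; it is only an approximation of the layer graph (it omits the size-1 fc_layer applied before the softmax and all learned projections), and every name in it is hypothetical.

    import numpy as np

    def one_head_dot_product_attention(q, K, V, key_proj_size):
        # q: (key_proj_size,), K: (T, key_proj_size), V: (T, value_proj_size)
        scores = (K @ q) * np.sqrt(1.0 / key_proj_size)  # dot_prod_layer + slope_intercept_layer
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()                         # sequence softmax over the T steps
        return weights @ V                               # scaling_layer followed by sum pooling

    rng = np.random.default_rng(0)
    T, kp, vp = 6, 64, 64
    context = one_head_dot_product_attention(
        rng.normal(size=kp), rng.normal(size=(T, kp)), rng.normal(size=(T, vp)), kp)
    print(context.shape)  # (64,); concatenating head_num such vectors gives value_proj_size * head_num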