提交 afc6343e 编写于 作者: D dangqingqing

Refine sequence max-pooling and add unit testing of gradient check.

上级 dfe851a0
......@@ -141,6 +141,7 @@ set(DEPS_OPS
pool_with_index_op
nccl_op
sequence_conv_op
sequence_pool_op
lstm_op)
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
......@@ -153,6 +154,7 @@ if(WITH_GPU)
op_library(nccl_op DEPS nccl_common)
endif()
op_library(sequence_conv_op DEPS context_project)
op_library(sequence_pool_op DEPS sequence_pooling)
op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(dynamic_recurrent_op SRCS dynamic_recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS net_op tensor_array)
......
......@@ -8,6 +8,7 @@ if(WITH_GPU)
nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
nv_library(sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
......@@ -18,6 +19,7 @@ else()
cc_library(softmax SRCS softmax.cc DEPS operator)
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
cc_library(pooling SRCS pooling.cc DEPS device_context)
nv_library(sequence_pooling SRCS sequence_pooling.cc DEPS device_context math_function)
cc_library(vol2col SRCS vol2col.cc DEPS device_context)
cc_library(context_project SRCS context_project.cc DEPS device_context)
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/sequence_pooling.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
class MaxSeqPoolFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& input, framework::Tensor* output,
framework::Tensor* index) {
auto in_dims = input.dims();
auto out_dims = output->dims();
auto idx_dims = index->dims();
PADDLE_ENFORCE_GT(in_dims.size(), 1UL);
PADDLE_ENFORCE_GT(out_dims.size(), 1UL);
for (size_t i = 1; i < in_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i]);
}
PADDLE_ENFORCE_EQ(idx_dims, out_dims);
auto starts = input.lod()[0];
const T* in_data = input.data<T>();
T* out_data = output->data<T>();
int* max_index = index->data<int>();
int64_t num_seq = out_dims[0];
int64_t dim = output->numel() / num_seq;
for (int64_t i = 0; i < num_seq; ++i) {
for (int64_t k = 0; k < dim; ++k) {
out_data[i * dim + k] = in_data[starts[i] * dim + k];
max_index[i * dim + k] = starts[i];
}
for (size_t j = starts[i] + 1; j < starts[i + 1]; ++j) {
for (int64_t k = 0; k < dim; ++k) {
if (in_data[j * dim + k] > out_data[i * dim + k]) {
out_data[i * dim + k] = in_data[j * dim + k];
max_index[i * dim + k] = j;
}
}
}
}
}
};
template <typename T>
class MaxSeqPoolGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& out_grad,
const framework::Tensor& index,
framework::LoDTensor* in_grad) {
auto og_dims = out_grad.dims();
auto ig_dims = in_grad->dims();
auto idx_dims = index.dims();
PADDLE_ENFORCE_GT(og_dims.size(), 1UL);
PADDLE_ENFORCE_GT(ig_dims.size(), 1UL);
for (size_t i = 1; i < og_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(og_dims[i], ig_dims[i]);
}
PADDLE_ENFORCE_EQ(idx_dims, og_dims);
const T* og_data = out_grad.data<T>();
const int* max_index = index.data<int>();
T* ig_data = in_grad->data<T>();
SetConstant<platform::CPUPlace, T> set_zero;
set_zero(context, in_grad, static_cast<T>(0.0));
int64_t num_seq = og_dims[0];
int64_t dim = out_grad.numel() / num_seq;
for (size_t i = 0; i < num_seq; ++i) {
for (size_t j = 0; j < dim; ++j) {
int step_id = max_index[i * dim + j];
ig_data[step_id * dim + j] = og_data[i * dim + j];
}
}
}
};
template class MaxSeqPoolFunctor<platform::CPUPlace, float>;
template class MaxSeqPoolFunctor<platform::CPUPlace, double>;
template class MaxSeqPoolGradFunctor<platform::CPUPlace, float>;
template class MaxSeqPoolGradFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence_pooling.h"
namespace paddle {
namespace operators {
namespace math {
#define FLT_MAX __FLT_MAX__
template <typename T>
__global__ void KeMaxSequencePool(const T* input, const size_t* starts,
T* output, int* index, int64_t num_seq,
int64_t dim) {
int dim_idx = threadIdx.x;
int seq_id = blockIdx.x;
if (seq_id >= num_seq) return;
size_t start = starts[seq_id];
size_t end = starts[seq_id + 1];
for (int i = dim_idx; i < dim; i += blockDim.x) {
T max_val = static_cast<T>(-FLT_MAX);
int max_id = -1;
for (size_t step_id = start; step_id < end; step_id++) {
if (max_val < input[step_id * dim + i]) {
max_val = input[step_id * dim + i];
max_id = step_id;
}
}
output[seq_id * dim + i] = max_val;
index[seq_id * dim + i] = max_id;
}
}
template <typename T>
class MaxSeqPoolFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& input, framework::Tensor* output,
framework::Tensor* index) {
auto in_dims = input.dims();
auto out_dims = output->dims();
auto idx_dims = index->dims();
PADDLE_ENFORCE_GT(in_dims.size(), 1UL);
PADDLE_ENFORCE_GT(out_dims.size(), 1UL);
for (size_t i = 1; i < in_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i]);
}
PADDLE_ENFORCE_EQ(idx_dims, out_dims);
auto starts = input.lod()[0];
const T* in_data = input.data<T>();
T* out_data = output->data<T>();
int* max_index = index->data<int>();
int64_t num_seq = out_dims[0];
int64_t dim = output->numel() / num_seq;
dim3 threads(256, 1);
dim3 grid(num_seq, 1);
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
KeMaxSequencePool<T><<<grid, threads, 0, stream>>>(
in_data, starts.data(), out_data, max_index, num_seq, dim);
}
};
template <typename T>
__global__ void KeMaxSequencePoolGrad(const T* out_grad, const int* max_index,
T* in_grad, int64_t num_seq,
int64_t dim) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
int col_idx = idx % dim;
if (idx < num_seq * dim) {
int step_id = max_index[idx];
in_grad[step_id * dim + col_idx] = out_grad[idx];
}
}
template <typename T>
class MaxSeqPoolGradFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& out_grad,
const framework::Tensor& index,
framework::LoDTensor* in_grad) {
auto og_dims = out_grad.dims();
auto idx_dims = index.dims();
auto ig_dims = in_grad->dims();
PADDLE_ENFORCE_GT(og_dims.size(), 1UL);
PADDLE_ENFORCE_GT(ig_dims.size(), 1UL);
for (size_t i = 1; i < og_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(og_dims[i], ig_dims[i]);
}
PADDLE_ENFORCE_EQ(idx_dims, og_dims);
const T* og_data = out_grad.data<T>();
const int* max_index = index.data<int>();
T* ig_data = in_grad->data<T>();
SetConstant<platform::GPUPlace, T> set_zero;
set_zero(context, in_grad, static_cast<T>(0.0));
int64_t num_seq = og_dims[0];
int64_t dim = out_grad.numel() / num_seq;
unsigned int blocks = (num_seq * dim + 128 - 1) / 128;
dim3 threads(128, 1);
dim3 grid(blocks, 1);
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
KeMaxSequencePoolGrad<T><<<grid, threads, 0, stream>>>(
og_data, max_index, ig_data, num_seq, dim);
}
};
template class MaxSeqPoolFunctor<platform::GPUPlace, float>;
template class MaxSeqPoolFunctor<platform::GPUPlace, double>;
template class MaxSeqPoolGradFunctor<platform::GPUPlace, float>;
template class MaxSeqPoolGradFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
#define FLT_MAX __FLT_MAX__
template <typename Place, typename T>
class MaxSeqPoolFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& input, framework::Tensor* output,
framework::Tensor* index);
};
template <typename Place, class T>
class MaxSeqPoolGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& out_grad,
const framework::Tensor& index,
framework::LoDTensor* in_grad);
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -27,6 +27,11 @@ class SequencePoolOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequencePoolOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
if (ctx->Attrs().Get<std::string>("pooltype") == "MAX") {
PADDLE_ENFORCE(ctx->HasOutput("MaxIndex"),
"Output(MaxIndex) of SequencePoolOp should not be null.");
ctx->SetOutputDim("MaxIndex", ctx->GetInputDim("X"));
}
}
};
......@@ -35,13 +40,17 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
SequencePoolOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor), the variable-length input of SequencePoolOp");
AddInput("X", "(LoDTensor) The variable-length input of SequencePoolOp");
AddOutput("Out",
"(Tensor), output of SequencePoolOp, which does not contain LoD "
"(Tensor) The output of SequencePoolOp does not contain LoD "
"infomation.");
AddOutput("MaxIndex",
"(Tensor<int>) This tensor is used for the max-pooling "
"of sequence to record the max indexes.")
.AsIntermediate();
AddAttr<std::string>(
"pooltype",
"(int, default AVERAGE) the pooling pooltype of SequencePoolOp.")
"(int, default AVERAGE) The pooling pooltype of SequencePoolOp.")
.SetDefault("AVERAGE");
AddComment(R"DOC(
SequencePoolOp pools features of all time-steps of each instance.
......@@ -92,6 +101,12 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
}
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
protected:
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
}
};
} // namespace operators
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence_pooling.h"
namespace paddle {
namespace operators {
......@@ -34,7 +35,7 @@ class SequencePoolKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
auto* out = context.Output<Tensor>("Out");
std::string pooltype = context.Attr<std::string>("pooltype");
auto dims = in->dims();
......@@ -53,6 +54,16 @@ class SequencePoolKernel : public framework::OpKernel<T> {
auto lod_level_0 = lod[0];
out->mutable_data<T>(context.GetPlace());
if (pooltype == "MAX") {
math::MaxSeqPoolFunctor<Place, T> max_pool;
auto* index = context.Output<Tensor>("MaxIndex");
index->Resize({dims});
index->mutable_data<int>(context.GetPlace());
max_pool(context.device_context(), *in, out, index);
return;
}
auto place = context.GetEigenDevice<Place>();
for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
Tensor in_t = in->Slice(static_cast<int>(lod_level_0[i]),
......@@ -69,8 +80,6 @@ class SequencePoolKernel : public framework::OpKernel<T> {
} else if (pooltype == "SQRT") {
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
std::sqrt(static_cast<T>(h));
} else if (pooltype == "MAX") {
out_e.device(place) = in_e.maximum(Eigen::array<int, 1>({{0}}));
} else if (pooltype == "LAST") {
out_e.device(place) = in_e.chip(h - 1, 0);
} else if (pooltype == "FIRST") {
......@@ -87,8 +96,8 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out_g = context.Input<Tensor>(framework::GradVarName("Out"));
auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
auto* out_g = context.Input<LoDTensor>(framework::GradVarName("Out"));
std::string pooltype = context.Attr<std::string>("pooltype");
auto dims = in->dims();
......@@ -96,6 +105,14 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
int64_t w = in->numel() / dims[0];
in_g->mutable_data<T>(context.GetPlace());
if (pooltype == "MAX") {
math::MaxSeqPoolGradFunctor<Place, T> max_pool_grad;
auto* index = context.Input<Tensor>("MaxIndex");
max_pool_grad(context.device_context(), *out_g, *index, in_g);
return;
}
if (pooltype == "LAST" || pooltype == "FIRST") {
// set X@Grad be zero at first when pooltype is LAST/FIRST
math::SetConstant<Place, T> functor;
......@@ -118,20 +135,6 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
} else if (pooltype == "SQRT") {
in_g_e.device(place) =
(out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
} else if (pooltype == "MAX") {
auto in_t =
in->Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
in_t_map(in_t.data<T>(), h, w);
int row_id;
Eigen::array<int, 2> extents{{1, 1}};
for (int col_id = 0; col_id < w; col_id++) {
in_t_map.col(col_id).maxCoeff(&row_id);
Eigen::array<int, 2> in_offsets{{row_id, col_id}};
Eigen::array<int, 2> out_offsets{{0, col_id}};
in_g_e.slice(in_offsets, extents).device(place) =
out_g_e.slice(out_offsets, extents);
}
} else if (pooltype == "LAST") {
in_g_e.chip(h - 1, 0).device(place) = out_g_e;
} else if (pooltype == "FIRST") {
......
......@@ -29,6 +29,9 @@ class TestSeqAvgPool(OpTest):
self.check_output()
def test_check_grad(self):
# Remove MaxIndex after check_grad is refined.
self.outputs['MaxIndex'] = \
np.zeros(self.outputs['Out'].shape).astype('int32')
self.check_grad(["X"], "Out")
......@@ -85,31 +88,53 @@ class TestSeqSqrtPool2D(TestSeqAvgPool2D):
out[i] = np.reshape(sub_x.sum(axis=0) / np.sqrt(len), (3, 17))
def test_check_grad(self):
# Remove MaxIndex after check_grad is refined.
self.outputs['MaxIndex'] = \
np.zeros(self.outputs['Out'].shape).astype('int32')
self.check_grad(["X"], "Out", max_relative_error=0.06)
class TestSeqMaxPool(TestSeqAvgPool):
def set_data(self):
self.op_type = 'sequence_pool'
x = np.random.uniform(0.1, 1, [13, 23]).astype('float32')
lod = [[0, 4, 5, 8, 13]]
for i in range(4):
l = lod[0][i + 1] - lod[0][i]
x[lod[0][i] + np.random.randint(l), :] += 2.0
self.inputs = {'X': (x, lod)}
out = np.zeros((4, 23)).astype('float32')
self.outputs = {'Out': out}
return x, lod, out
def compute(self, x, lod, out):
self.attrs = {'pooltype': "MAX"}
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
out[i] = np.amax(sub_x, axis=0)
def test_check_grad(self):
# Remove MaxPool2D from gradient check to confirm the success of CI.
return
class TestSeqMaxPool2D(TestSeqAvgPool2D):
def set_data(self):
self.op_type = 'sequence_pool'
x = np.random.uniform(0.1, 1, [13, 3, 11]).astype('float32')
lod = [[0, 4, 5, 8, 13]]
self.inputs = {'X': (x, lod)}
for i in range(4):
l = lod[0][i + 1] - lod[0][i]
x[lod[0][i] + np.random.randint(l), :] += 1.0
out = np.zeros((4, 3, 11)).astype('float32')
self.outputs = {'Out': out}
return x, lod, out
def compute(self, x, lod, out):
self.attrs = {'pooltype': "MAX"}
for i in range(4):
sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17))
out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 17))
def test_check_grad(self):
# Remove MaxPool2D from gradient check to confirm the success of CI.
return
sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 11))
out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 11))
class TestSeqLastPool(TestSeqAvgPool):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册