提交 b28b2f17 编写于 作者: Q QI JUN 提交者: Yu Yang

refine test_recognize_digits_mlp and format codes (#5937)

上级 d89ff5b6
......@@ -55,7 +55,7 @@ paddle_error paddle_matrix_set_row(paddle_matrix mat,
}
PD_API paddle_error paddle_matrix_set_value(paddle_matrix mat,
paddle_real* value) {
paddle_real* value) {
if (mat == nullptr || value == nullptr) return kPD_NULLPTR;
auto ptr = cast(mat);
if (ptr->mat == nullptr) return kPD_NULLPTR;
......@@ -75,7 +75,7 @@ PD_API paddle_error paddle_matrix_set_value(paddle_matrix mat,
}
PD_API paddle_error paddle_matrix_get_value(paddle_matrix mat,
paddle_real* result) {
paddle_real* result) {
if (mat == nullptr || result == nullptr) return kPD_NULLPTR;
auto ptr = cast(mat);
if (ptr->mat == nullptr) return kPD_NULLPTR;
......
......@@ -79,7 +79,7 @@ PD_API paddle_error paddle_matrix_set_row(paddle_matrix mat,
* @note value should contain enough element of data to init the mat
*/
PD_API paddle_error paddle_matrix_set_value(paddle_matrix mat,
paddle_real* value);
paddle_real* value);
/**
* @brief PDMatGetRow Get raw row buffer from matrix
......@@ -93,14 +93,14 @@ PD_API paddle_error paddle_matrix_get_row(paddle_matrix mat,
paddle_real** rawRowBuffer);
/**
* @brief copy data from the matrix
* @brief copy data from the matrix
* @param [in] mat Target matrix
* @param [out] result pointer to store the matrix data
* @param [out] result pointer to store the matrix data
* @return paddle_error
* @note the space of the result should allocated before invoke this API
*/
PD_API paddle_error paddle_matrix_get_value(paddle_matrix mat,
paddle_real* result);
paddle_real* result);
/**
* @brief PDMatCreateNone Create None Matrix
* @return
......
......@@ -135,18 +135,17 @@ inline void CopyToVector(const Tensor& src, const platform::DeviceContext& ctx,
auto dst_ptr = static_cast<void*>(dst->data());
if (platform::is_cpu_place(src.place())) {
memory::Copy(dst_place, dst_ptr, boost::get<platform::CPUPlace>(src.place()),
src_ptr, size);
memory::Copy(dst_place, dst_ptr,
boost::get<platform::CPUPlace>(src.place()), src_ptr, size);
}
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(src.place())) { // NOLINT
memory::Copy(
dst_place, dst_ptr, boost::get<platform::GPUPlace>(src.place()), src_ptr,
size,
dst_place, dst_ptr, boost::get<platform::GPUPlace>(src.place()),
src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
}
#endif
}
} // namespace framework
......
......@@ -23,8 +23,7 @@ template <typename T>
class MaxOutFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
framework::Tensor * output,
const framework::Tensor& input, framework::Tensor* output,
int groups) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
......@@ -37,34 +36,30 @@ class MaxOutFunctor<platform::CPUPlace, T> {
T* output_data = output->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; ++i) {
int new_bindex = c_size * i;
int new_bindex = c_size * i;
for (int c = 0; c < output_channels; ++c) {
int new_cindex = fea_size * c;
for (int f = 0; f < fea_size; ++f) {
T ele = static_cast<T>(-FLT_MAX);
for (int ph = 0; ph < groups; ++ph) {
T x = input_data[(new_bindex + new_cindex) * groups
+ ph * fea_size + f];
T x = input_data[(new_bindex + new_cindex) * groups +
ph * fea_size + f];
ele = ele > x ? ele : x;
}
output_data[(new_bindex+new_cindex+f)] = ele;
output_data[(new_bindex + new_cindex + f)] = ele;
}
}
}
}
};
template <class T>
class MaxOutGradFunctor<platform::CPUPlace, T> {
public:
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
framework::Tensor * input_grad,
const framework::Tensor& input, framework::Tensor* input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad,
int groups) {
const framework::Tensor& output_grad, int groups) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
......@@ -84,11 +79,11 @@ public:
bool continue_match = true;
int output_idx = blen + clen + f;
for (int g = 0; g < groups && continue_match; ++g) {
int input_idx = input_idx0 + fea_size * g;
if (input_data[input_idx] == output_data[output_idx]) {
input_grad_data[input_idx] += output_grad_data[output_idx];
continue_match = false;
}
int input_idx = input_idx0 + fea_size * g;
if (input_data[input_idx] == output_data[output_idx]) {
input_grad_data[input_idx] += output_grad_data[output_idx];
continue_match = false;
}
}
}
}
......
......@@ -21,9 +21,9 @@ namespace math {
template <typename T>
__global__ void KernelMaxOut(const int nthreads, const T* input_data,
const int channels,
const int input_height, const int input_width,
int groups, T* output_data ) {
const int channels, const int input_height,
const int input_width, int groups,
T* output_data) {
const int size = input_height * input_width * channels / groups;
const int feat_len = input_height * input_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
......@@ -34,7 +34,7 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data,
int channel_idx = batch_offset / feat_len;
int feat_idx = batch_offset % feat_len;
int data_idx =
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
T ele = static_cast<T>(-FLT_MAX);
for (int g = 0; g < groups; ++g) {
T x = input_data[data_idx + g * feat_len];
......@@ -44,34 +44,35 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data,
}
}
template <typename T>
__global__ void KernelMaxoutGrad(
const int nthreads, const T* input_data, const T* output_data,
const T* output_grad, T* input_grad, const int channels,
const int input_height, const int input_width, int groups) {
const int size = input_height * input_width * channels / groups;
const int feat_len = input_height * input_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int batch_idx = i / size;
int batch_offset = i % size;
int channel_idx = batch_offset / feat_len;
int feat_idx = batch_offset % feat_len;
int data_idx =
__global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
const T* output_data, const T* output_grad,
T* input_grad, const int channels,
const int input_height, const int input_width,
int groups) {
const int size = input_height * input_width * channels / groups;
const int feat_len = input_height * input_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int batch_idx = i / size;
int batch_offset = i % size;
int channel_idx = batch_offset / feat_len;
int feat_idx = batch_offset % feat_len;
int data_idx =
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
int max_index = -1;
bool continue_match = true;
for (int g = 0; g < groups && continue_match; ++g) {
if (input_data[data_idx + g * feat_len] == output_data[i]) {
max_index = data_idx + g * feat_len;
continue_match = false;
break;
}
}
if (max_index != -1) {
input_grad[max_index] += output_grad[index];
int max_index = -1;
bool continue_match = true;
for (int g = 0; g < groups && continue_match; ++g) {
if (input_data[data_idx + g * feat_len] == output_data[i]) {
max_index = data_idx + g * feat_len;
continue_match = false;
break;
}
}
if (max_index != -1) {
input_grad[max_index] += output_grad[index];
}
}
}
/*
* All tensors are in NCHW format.
......@@ -80,7 +81,7 @@ template <typename T>
class MaxOutFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor * output,
const framework::Tensor& input, framework::Tensor* output,
int groups) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
......@@ -92,7 +93,7 @@ class MaxOutFunctor<platform::GPUPlace, T> {
const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
int nthreads = output->numel();
int nthreads = output->numel();
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
......@@ -101,8 +102,7 @@ class MaxOutFunctor<platform::GPUPlace, T> {
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(nthreads, input_data, input_channels,
input_height, input_width, groups,
output_data);
input_height, input_width, groups, output_data);
}
};
/*
......@@ -112,11 +112,9 @@ template <typename T>
class MaxOutGradFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
framework::Tensor * input_grad,
const framework::Tensor& input, framework::Tensor* input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad,
int groups) {
const framework::Tensor& output_grad, int groups) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
......@@ -129,7 +127,7 @@ class MaxOutGradFunctor<platform::GPUPlace, T> {
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
int nthreads = output.numel();
int nthreads = output.numel();
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
......@@ -137,9 +135,9 @@ class MaxOutGradFunctor<platform::GPUPlace, T> {
KernelMaxoutGrad<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_grad_data,
input_channels, input_height, input_width, groups);
.stream()>>>(nthreads, input_data, output_data,
output_grad_data, input_grad_data, input_channels,
input_height, input_width, groups);
}
};
......
......@@ -21,15 +21,14 @@ namespace paddle {
namespace operators {
namespace math {
#define FLT_MAX \
__FLT_MAX__
#define FLT_MAX __FLT_MAX__
template <typename Place, typename T>
class MaxOutFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor * output,
const framework::Tensor& input, framework::Tensor* output,
int groups);
};
......@@ -37,8 +36,7 @@ template <typename Place, class T>
class MaxOutGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
framework::Tensor * input_grad,
const framework::Tensor& input, framework::Tensor* input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, int groups);
};
......
......@@ -22,16 +22,17 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxOutOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
AddInput(
"X",
"(Tensor) The input tensor of maxout operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature.");
AddOutput("Out",
"(Tensor) The output tensor of maxout operator."
"The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
"width of feature.");
"(Tensor) The output tensor of maxout operator."
"The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
"width of feature.");
AddAttr<int>(
"groups",
R"DOC("Specifies how many groups the input tensor will be split"
......@@ -59,21 +60,19 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class MaxOutOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MaxoutOp"
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of MaxoutOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of MaxoutOp should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
int groups = ctx->Attrs().Get<int>("groups");
// check groups > 1
PADDLE_ENFORCE_GT(
groups, 1,
"groups should be larger than 1 in maxoutop");
PADDLE_ENFORCE_GT(groups, 1, "groups should be larger than 1 in maxoutop");
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1] / groups});
output_shape.push_back(in_x_dims[2]);
output_shape.push_back(in_x_dims[3]);
......@@ -87,18 +86,17 @@ class MaxOutOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input(X@GRAD) should not be null.");
"Input(X@GRAD) should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
} // namespace operators
} // namespace paddle
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(maxout, ops::MaxOutOp, ops::MaxOutOpMaker, maxout_grad,
ops::MaxOutOpGrad);
REGISTER_OP_CPU_KERNEL(maxout, ops::MaxOutKernel<paddle::platform::CPUPlace,
float>);
REGISTER_OP_CPU_KERNEL(maxout_grad,
ops::MaxOutGradKernel<paddle::platform::CPUPlace,
float>);
ops::MaxOutOpGrad);
REGISTER_OP_CPU_KERNEL(maxout,
ops::MaxOutKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
maxout_grad, ops::MaxOutGradKernel<paddle::platform::CPUPlace, float>);
......@@ -18,8 +18,6 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(maxout,
ops::MaxOutKernel<paddle::platform::GPUPlace, float>,
ops::MaxOutKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(maxout_grad,
ops::MaxOutGradKernel<paddle::platform::GPUPlace,
float>,
ops::MaxOutGradKernel<paddle::platform::GPUPlace,
double>);
REGISTER_OP_GPU_KERNEL(
maxout_grad, ops::MaxOutGradKernel<paddle::platform::GPUPlace, float>,
ops::MaxOutGradKernel<paddle::platform::GPUPlace, double>);
......@@ -53,7 +53,7 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
zero(device_ctx, in_x_grad, static_cast<T>(0.0));
math::MaxOutGradFunctor<Place, T> maxout_backward;
maxout_backward(context.device_context(), *in_x, in_x_grad, *out,
*out_grad, groups);
*out_grad, groups);
}
}
};
......
......@@ -43,8 +43,8 @@ class ROIPoolOp : public framework::OperatorWithKernel {
"ROIs should be a 2-D tensor of shape (num_rois, 5)"
"given as [[batch_id, x1, y1, x2, y2], …].");
PADDLE_ENFORCE(rois_dims[1] == kROISize,
"ROIs should be a 2-D tensor of shape (num_rois, 5)"
"given as [[batch_id, x1, y1, x2, y2], …].");
"ROIs should be a 2-D tensor of shape (num_rois, 5)"
"given as [[batch_id, x1, y1, x2, y2], …].");
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
......@@ -65,7 +65,7 @@ class ROIPoolOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", out_dims);
ctx->SetOutputDim("Argmax", out_dims);
}
}
protected:
framework::OpKernelType GetKernelType(
......@@ -100,7 +100,7 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ROIPoolOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor), "
......@@ -125,21 +125,22 @@ class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor), "
"Argmaxes corresponding to indices in X used "
"for gradient computation. Only output "
"if arg “is_test” is false.").AsIntermediate();
"if arg “is_test” is false.")
.AsIntermediate();
AddAttr<float>("spatial_scale",
"(float, default 1.0), "
"Multiplicative spatial scale factor "
"to translate ROI coords from their input scale "
"to the scale used when pooling.")
.SetDefault(1.0);
.SetDefault(1.0);
AddAttr<int>("pooled_height",
"(int, default 1), "
"The pooled output height.")
.SetDefault(1);
.SetDefault(1);
AddAttr<int>("pooled_width",
"(int, default 1), "
"The pooled output width.")
.SetDefault(1);
.SetDefault(1);
AddComment(R"DOC(
ROIPool operator
......@@ -153,11 +154,10 @@ https://stackoverflow.com/questions/43430056/what-is-roi-layer-in-fast-rcnn
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker,
roi_pool_grad, ops::ROIPoolGradOp);
REGISTER_OP(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker, roi_pool_grad,
ops::ROIPoolGradOp);
REGISTER_OP_CPU_KERNEL(
roi_pool,
ops::CPUROIPoolOpKernel<paddle::platform::CPUPlace, float>,
roi_pool, ops::CPUROIPoolOpKernel<paddle::platform::CPUPlace, float>,
ops::CPUROIPoolOpKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
roi_pool_grad,
......
......@@ -29,101 +29,95 @@ static inline int NumBlocks(const int N) {
kNumMaxinumNumBlocks);
}
template <typename T>
__global__ void GPUROIPoolForward(
const int nthreads, const T* input_data, const int64_t* input_rois,
const float spatial_scale, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
T* output_data, int64_t* argmax_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (size_t i = index; i < nthreads; i += offset) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const int64_t* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = offset_input_rois[0];
int roi_start_w = round(offset_input_rois[1] * spatial_scale);
int roi_start_h = round(offset_input_rois[2] * spatial_scale);
int roi_end_w = round(offset_input_rois[3] * spatial_scale);
int roi_end_h = round(offset_input_rois[4] * spatial_scale);
int roi_width = max(roi_end_w - roi_start_w + 1, 1);
int roi_height = max(roi_end_h - roi_start_h + 1, 1);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
hstart = min(max(hstart + roi_start_h, 0), height);
hend = min(max(hend + roi_start_h, 0), height);
wstart = min(max(wstart + roi_start_w, 0), width);
wend = min(max(wend + roi_start_w, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
T maxval = is_empty ? 0 : -std::numeric_limits<T>::max();
int maxidx = -1;
const T* offset_input_data =
input_data + (roi_batch_ind * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_data_index = h * width + w;
if (offset_input_data[input_data_index] > maxval) {
maxval = offset_input_data[input_data_index];
maxidx = input_data_index;
}
template <typename T>
__global__ void GPUROIPoolForward(const int nthreads, const T* input_data,
const int64_t* input_rois,
const float spatial_scale, const int channels,
const int height, const int width,
const int pooled_height,
const int pooled_width, T* output_data,
int64_t* argmax_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (size_t i = index; i < nthreads; i += offset) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const int64_t* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = offset_input_rois[0];
int roi_start_w = round(offset_input_rois[1] * spatial_scale);
int roi_start_h = round(offset_input_rois[2] * spatial_scale);
int roi_end_w = round(offset_input_rois[3] * spatial_scale);
int roi_end_h = round(offset_input_rois[4] * spatial_scale);
int roi_width = max(roi_end_w - roi_start_w + 1, 1);
int roi_height = max(roi_end_h - roi_start_h + 1, 1);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
hstart = min(max(hstart + roi_start_h, 0), height);
hend = min(max(hend + roi_start_h, 0), height);
wstart = min(max(wstart + roi_start_w, 0), width);
wend = min(max(wend + roi_start_w, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
T maxval = is_empty ? 0 : -std::numeric_limits<T>::max();
int maxidx = -1;
const T* offset_input_data =
input_data + (roi_batch_ind * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_data_index = h * width + w;
if (offset_input_data[input_data_index] > maxval) {
maxval = offset_input_data[input_data_index];
maxidx = input_data_index;
}
}
output_data[index] = maxval;
if (argmax_data) {
argmax_data[index] = maxidx;
}
}
output_data[index] = maxval;
if (argmax_data) {
argmax_data[index] = maxidx;
}
}
}
template <typename T>
__global__ void GPUROIPoolBackward(
const int nthreads,
const int64_t* input_rois,
const T* output_grad,
const int64_t* argmax_data,
const int num_rois,
const float spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
T* input_grad) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const int64_t* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = offset_input_rois[0];
int input_offset = (roi_batch_ind * channels + c) * height * width;
int output_offset = (n * channels + c) * pooled_height * pooled_width;
const T* offset_output_grad = output_grad + output_offset;
T* offset_input_grad = input_grad + input_offset;
const int64_t* offset_argmax_data = argmax_data + output_offset;
int argmax = offset_argmax_data[ph * pooled_width + pw];
if (argmax != -1) {
platform::CudaAtomicAdd(offset_input_grad + argmax,
const int nthreads, const int64_t* input_rois, const T* output_grad,
const int64_t* argmax_data, const int num_rois, const float spatial_scale,
const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, T* input_grad) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const int64_t* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = offset_input_rois[0];
int input_offset = (roi_batch_ind * channels + c) * height * width;
int output_offset = (n * channels + c) * pooled_height * pooled_width;
const T* offset_output_grad = output_grad + output_offset;
T* offset_input_grad = input_grad + input_offset;
const int64_t* offset_argmax_data = argmax_data + output_offset;
int argmax = offset_argmax_data[ph * pooled_width + pw];
if (argmax != -1) {
platform::CudaAtomicAdd(
offset_input_grad + argmax,
static_cast<T>(offset_output_grad[ph * pooled_width + pw]));
}
}
}
}
template <typename Place, typename T>
class GPUROIPoolOpKernel : public framework::OpKernel<T> {
......@@ -145,25 +139,18 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> {
int width = in_dims[3];
size_t rois_num = rois->dims()[0];
if (rois_num== 0) return;
if (rois_num == 0) return;
int output_size = out->numel();
int blocks = NumBlocks(output_size);
int threads = kNumCUDAThreads;
GPUROIPoolForward<T>
<<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_size,
in->data<T>(),
rois->data<int64_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
out->mutable_data<T>(ctx.GetPlace()),
argmax->mutable_data<int64_t>(ctx.GetPlace()));
GPUROIPoolForward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_size, in->data<T>(), rois->data<int64_t>(), spatial_scale,
channels, height, width, pooled_height, pooled_width,
out->mutable_data<T>(ctx.GetPlace()),
argmax->mutable_data<int64_t>(ctx.GetPlace()));
}
};
......@@ -175,10 +162,8 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
auto* rois = ctx.Input<Tensor>("ROIs");
auto* argmax = ctx.Input<Tensor>("Argmax");
auto* out_grad =
ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* x_grad =
ctx.Output<Tensor>(framework::GradVarName("X"));
auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
......@@ -199,21 +184,13 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
int threads = kNumCUDAThreads;
if (output_grad_size > 0) {
GPUROIPoolBackward<T>
<<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_grad_size,
rois->data<int64_t>(),
out_grad->data<T>(),
argmax->data<int64_t>(),
rois_num,
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
x_grad->mutable_data<T>(ctx.GetPlace()));
}
GPUROIPoolBackward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_grad_size, rois->data<int64_t>(), out_grad->data<T>(),
argmax->data<int64_t>(), rois_num, spatial_scale, channels, height,
width, pooled_height, pooled_width,
x_grad->mutable_data<T>(ctx.GetPlace()));
}
}
}
};
......@@ -223,8 +200,7 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
roi_pool,
ops::GPUROIPoolOpKernel<paddle::platform::GPUPlace, float>,
roi_pool, ops::GPUROIPoolOpKernel<paddle::platform::GPUPlace, float>,
ops::GPUROIPoolOpKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
roi_pool_grad,
......
......@@ -136,8 +136,7 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* x_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
......
......@@ -45,7 +45,7 @@ class SequenceSliceOp : public framework::OperatorWithKernel {
// Initialize the output's dims to maximum,
// and re-set to real dims by the value of Offset and Length at kernel
ctx->SetOutputDim("Out", input_dims);
}
}
protected:
framework::OpKernelType GetKernelType(
......@@ -93,8 +93,7 @@ class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor), "
"a vector<int> to describe the length of every input sequence for "
"sub sequence item.");
AddOutput("Out",
"(LoDTensor), the output of SequenceSliceOp.");
AddOutput("Out", "(LoDTensor), the output of SequenceSliceOp.");
AddComment(R"DOC(
Sequence slice operator
......
......@@ -38,6 +38,7 @@ UCI_TEST_DATA = None
URL_MODEL = 'https://github.com/PaddlePaddle/book/raw/develop/01.fit_a_line/fit_a_line.tar'
MD5_MODEL = '52fc3da8ef3937822fcdd87ee05c0c9b'
def feature_range(maximums, minimums):
import matplotlib
matplotlib.use('Agg')
......@@ -114,7 +115,8 @@ def test():
def model():
tar_file = paddle.v2.dataset.common.download(URL_MODEL, 'fit_a_line.tar', MD5_MODEL)
tar_file = paddle.v2.dataset.common.download(URL_MODEL, 'fit_a_line.tar',
MD5_MODEL)
with open(tar_file, 'r') as f:
parameters = Parameters.from_tar(f)
return parameters
......
......@@ -35,6 +35,13 @@ opts = optimizer.minimize(avg_cost)
accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
inference_program = fluid.default_main_program().clone()
test_accuracy = fluid.evaluator.Accuracy(
input=predict, label=label, main_program=inference_program)
test_target = [avg_cost] + test_accuracy.metrics + test_accuracy.states
inference_program = fluid.io.get_inference_program(
test_target, main_program=inference_program)
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192),
......@@ -69,11 +76,6 @@ for pass_id in range(PASS_NUM):
acc = np.array(outs[1])
pass_acc = accuracy.eval(exe)
test_accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
test_target = [avg_cost] + test_accuracy.metrics + test_accuracy.states
inference_program = fluid.io.get_inference_program(test_target)
test_accuracy.reset(exe)
for data in test_reader():
x_data = np.array(map(lambda x: x[0], data)).astype("float32")
......
......@@ -30,9 +30,7 @@ class TestMaxOutOp(OpTest):
def init_test_case(self):
self.MaxOut_forward_naive = maxout_forward_naive
self.shape = [100, 6, 2, 2]
self.groups=2
self.groups = 2
if __name__ == '__main__':
......
......@@ -4,24 +4,22 @@ import math
import sys
from op_test import OpTest
class TestROIPoolOp(OpTest):
def set_data(self):
self.init_test_case()
self.make_rois()
self.calc_roi_pool()
self.inputs = {
'X': self.x,
'ROIs': self.rois}
self.inputs = {'X': self.x, 'ROIs': self.rois}
self.attrs = {
'spatial_scale': self.spatial_scale,
'pooled_height': self.pooled_height,
'pooled_width': self.pooled_width}
'pooled_width': self.pooled_width
}
self.outputs = {
'Out': self.outs,
'Argmax': self.argmaxes}
self.outputs = {'Out': self.outs, 'Argmax': self.argmaxes}
def init_test_case(self):
self.batch_size = 5
......@@ -30,10 +28,9 @@ class TestROIPoolOp(OpTest):
self.width = 4
# n, c, h, w
self.x_dim = (self.batch_size, self.channels,
self.height, self.width)
self.x_dim = (self.batch_size, self.channels, self.height, self.width)
self.spatial_scale = 1.0/4.0
self.spatial_scale = 1.0 / 4.0
self.pooled_height = 2
self.pooled_width = 2
self.rois_num = 2
......@@ -41,13 +38,11 @@ class TestROIPoolOp(OpTest):
self.x = np.random.random(self.x_dim).astype('float32')
def calc_roi_pool(self):
out_data = np.zeros(
(self.rois_num, self.channels,
self.pooled_height, self.pooled_width))
argmax_data = np.zeros(
(self.rois_num, self.channels,
self.pooled_height, self.pooled_width))
out_data = np.zeros((self.rois_num, self.channels, self.pooled_height,
self.pooled_width))
argmax_data = np.zeros((self.rois_num, self.channels,
self.pooled_height, self.pooled_width))
for i in range(self.rois_num):
roi = self.rois[i]
roi_batch_id = roi[0]
......@@ -56,8 +51,8 @@ class TestROIPoolOp(OpTest):
roi_end_w = int(round(roi[3] * self.spatial_scale))
roi_end_h = int(round(roi[4] * self.spatial_scale))
roi_height = int(max(roi_end_h - roi_start_h + 1, 1));
roi_width = int(max(roi_end_w - roi_start_w + 1, 1));
roi_height = int(max(roi_end_h - roi_start_h + 1, 1))
roi_width = int(max(roi_end_w - roi_start_w + 1, 1))
x_i = self.x[roi_batch_id]
......@@ -84,7 +79,7 @@ class TestROIPoolOp(OpTest):
out_data[i, c, ph, pw] = -sys.float_info.max
argmax_data[i, c, ph, pw] = -1
for h in range(hstart, hend):
for w in range(wstart, wend):
if x_i[c, h, w] > out_data[i, c, ph, pw]:
......@@ -104,11 +99,11 @@ class TestROIPoolOp(OpTest):
y1 = np.random.random_integers(
0, self.height / self.spatial_scale - self.pooled_height)
x2 = np.random.random_integers(
x1 + self.pooled_width, self.width / self.spatial_scale)
y2 = np.random.random_integers(
y1 + self.pooled_height, self.height / self.spatial_scale)
x2 = np.random.random_integers(x1 + self.pooled_width,
self.width / self.spatial_scale)
y2 = np.random.random_integers(y1 + self.pooled_height,
self.height / self.spatial_scale)
roi = [batch_ids[i], x1, y1, x2, y2]
rois.append(roi)
self.rois = np.array(rois).astype("int64")
......@@ -123,5 +118,6 @@ class TestROIPoolOp(OpTest):
def test_check_grad(self):
self.check_grad(['X'], 'Out')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册