Commit 5802880b authored by wanghaox

update maxoutop for code review 3

Parent commit 3ef776ef
@@ -22,23 +22,20 @@ namespace math {
  * All tensors are in NCHW format.
  * groups mustbe > 1
  */
-template <typename MaxOutProcess, typename T>
-class MaxOutFunctor<platform::CPUPlace, MaxOutProcess, T> {
+template <typename T>
+class MaxOutFunctor<platform::CPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input,
                   framework::Tensor * output,
-                  int groups,
-                  MaxOutProcess maxout_process) {
+                  int groups) {
     const int batch_size = input.dims()[0];
     const int input_height = input.dims()[2];
     const int input_width = input.dims()[3];
     const int output_channels = output->dims()[1];
     int fea_size = input_height * input_width;
-    // c_size mean output one batch size
+    // c_size means the output size of each sample
     int c_size = fea_size * output_channels;
     const T* input_data = input.data<T>();
     T* output_data = output->mutable_data<T>(context.GetPlace());
@@ -47,10 +44,11 @@ class MaxOutFunctor<platform::CPUPlace, MaxOutProcess, T> {
       for (int c = 0; c < output_channels; ++c) {
         int new_cindex = fea_size * c;
         for (int f = 0; f < fea_size; ++f) {
-          T ele = maxout_process.initial();
+          // T ele = maxout_process.initial();
+          T ele = static_cast<T>(-FLT_MAX);
           for (int ph = 0; ph < groups; ++ph) {
-            maxout_process.compute(ele,
-                input_data[(new_bindex+new_cindex) * groups+ph*fea_size+f]);
+            T x = input_data[(new_bindex+new_cindex) * groups+ph*fea_size+f];
+            ele = ele > x ? ele : x;
           }
           output_data[(new_bindex+new_cindex+f)] = ele;
         }
@@ -74,9 +72,7 @@ public:
     const int input_height = input.dims()[2];
     const int input_width = input.dims()[3];
     const int output_channels = output.dims()[1];
     int fea_size = input_height * input_width;
     const T* input_data = input.data<T>();
     const T* output_data = output.data<T>();
     const T* output_grad_data = output_grad.data<T>();
@@ -87,15 +83,15 @@ public:
       for (int c = 0; c < output_channels; ++c) {
         int clen = fea_size * c;
         for (int f = 0; f < fea_size; ++f) {
-          int input_idx = 0;
-          bool stop = false;
+          int input_idx0 = (blen + clen) * groups + f;
+          bool continue_match = true;
           int output_idx = blen + clen + f;
-          for (int g = 0; g < groups && !stop; ++g) {
-            input_idx = (blen + clen) * groups + fea_size * g + f;
+          for (int g = 0; g < groups && continue_match; ++g) {
+            int input_idx = input_idx0 + fea_size * g;
             input_grad_data[input_idx] = 0;
             if (input_data[input_idx] == output_data[output_idx]) {
               input_grad_data[input_idx] += output_grad_data[output_idx];
-              stop = true;
+              continue_match = false;
             }
           }
         }
@@ -106,10 +102,8 @@ public:
 template class MaxOutGradFunctor<platform::CPUPlace, float>;
 template class MaxOutGradFunctor<platform::CPUPlace, double>;
-template class MaxOutFunctor<platform::CPUPlace,
-                             math::MaxOut<float>, float>;
-template class MaxOutFunctor<platform::CPUPlace,
-                             math::MaxOut<double>, double>;
+template class MaxOutFunctor<platform::CPUPlace, float>;
+template class MaxOutFunctor<platform::CPUPlace, double>;
 }  // namespace math
 }  // namespace operators
...
@@ -19,27 +19,28 @@ namespace paddle {
 namespace operators {
 namespace math {
-template <typename MaxOutProcess, typename T>
+template <typename T>
 __global__ void KernelMaxOut(const int nthreads, const T* input_data,
                              const int channels,
                              const int input_height, const int input_width,
-                             int groups, T* output_data,
-                             MaxOutProcess maxout_process) {
+                             int groups, T* output_data) {
   const int size = input_height * input_width * channels / groups;
   const int feat_len = input_height * input_width;
-  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
-       index += blockDim.x * gridDim.x) {
-    int batch_idx = index / size;
-    int batch_offset = index % size;
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int offset = blockDim.x * gridDim.x;
+  for (int i = index; i < nthreads; i += offset) {
+    int batch_idx = i / size;
+    int batch_offset = i % size;
     int channel_idx = batch_offset / feat_len;
     int feat_idx = batch_offset % feat_len;
     int data_idx =
         (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
-    T ele = maxout_process.initial();
+    T ele = static_cast<T>(-FLT_MAX);
     for (int g = 0; g < groups; ++g) {
-      maxout_process.compute(ele, input_data[data_idx + g * feat_len]);
+      T x = input_data[data_idx + g * feat_len];
+      ele = ele > x ? ele : x;
     }
-    output_data[index] = ele;
+    output_data[i] = ele;
   }
 }
 template <typename T>
@@ -49,38 +50,38 @@ __global__ void KernelMaxoutGrad(
     const int input_height, const int input_width, int groups) {
   const int size = input_height * input_width * channels / groups;
   const int feat_len = input_height * input_width;
-  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
-       index += blockDim.x * gridDim.x) {
-    int batch_idx = index / size;
-    int batch_offset = index % size;
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int offset = blockDim.x * gridDim.x;
+  for (int i = index; i < nthreads; i += offset) {
+    int batch_idx = i / size;
+    int batch_offset = i % size;
     int channel_idx = batch_offset / feat_len;
     int feat_idx = batch_offset % feat_len;
     int data_idx =
         (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
-    int maxIndex = -1;
-    bool stop = false;
-    for (int g = 0; g < groups && !stop; ++g) {
-      if (input_data[data_idx + g * feat_len] == output_data[index]) {
-        maxIndex = data_idx + g * feat_len;
-        stop = true;
+    int max_index = -1;
+    bool continue_match = true;
+    for (int g = 0; g < groups && continue_match; ++g) {
+      if (input_data[data_idx + g * feat_len] == output_data[i]) {
+        max_index = data_idx + g * feat_len;
+        continue_match = false;
       }
     }
-    if (maxIndex != -1) {
+    if (max_index != -1) {
       // atomic add
-      platform::CudaAtomicAdd(input_grad + maxIndex, output_grad[index]);
+      platform::CudaAtomicAdd(input_grad + max_index, output_grad[index]);
     }
   }
 }
 /*
  * All tensors are in NCHW format.
  */
-template <typename MaxOutProcess, typename T>
-class MaxOutFunctor<platform::GPUPlace, MaxOutProcess, T> {
+template <typename T>
+class MaxOutFunctor<platform::GPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input, framework::Tensor * output,
-                  int groups,
-                  MaxOutProcess maxout_process) {
+                  int groups) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_height = input.dims()[2];
@@ -97,12 +98,11 @@ class MaxOutFunctor<platform::GPUPlace, MaxOutProcess, T> {
     dim3 grid(blocks, 1);
     KernelMaxOut<
-        MaxOutProcess,
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
                  .stream()>>>(nthreads, input_data, input_channels,
                               input_height, input_width, groups,
-                              output_data, maxout_process);
+                              output_data);
   }
 };
 /*
@@ -145,10 +145,8 @@ class MaxOutGradFunctor<platform::GPUPlace, T> {
 template class MaxOutGradFunctor<platform::GPUPlace, float>;
 template class MaxOutGradFunctor<platform::GPUPlace, double>;
-template class MaxOutFunctor<platform::GPUPlace,
-                             math::MaxOut<float>, float>;
-template class MaxOutFunctor<platform::GPUPlace,
-                             math::MaxOut<double>, double>;
+template class MaxOutFunctor<platform::GPUPlace, float>;
+template class MaxOutFunctor<platform::GPUPlace, double>;
 }  // namespace math
 }  // namespace operators
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
-#include "paddle/framework/eigen.h"
 #include "paddle/framework/tensor.h"
 #include "paddle/platform/device_context.h"
 #include "paddle/platform/hostdevice.h"
@@ -22,42 +21,18 @@ namespace paddle {
 namespace operators {
 namespace math {
 #define FLT_MAX \
         __FLT_MAX__
-/*
- * \brief Extracting simple operations from maxout.
- *        need "initial", "compute"
- *        operation.
- */
-template <class T>
-class MaxOut {
- public:
-  DEVICE inline T initial() { return static_cast<T>(-FLT_MAX); }
-  DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; }
-};
-template <class T>
-class MaxOutGrad {
- public:
-  DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx,
-                             T scale) {
-    dx += dy * (x == y);
-  }
-};
-template <typename Place, typename MaxOutProcess, typename T>
+template <typename Place, typename T>
 class MaxOutFunctor {
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input, framework::Tensor * output,
-                  int groups, MaxOutProcess maxout_compute);
+                  int groups);
 };
 template <typename Place, class T>
 class MaxOutGradFunctor {
  public:
@@ -67,13 +42,6 @@ class MaxOutGradFunctor {
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad, int groups);
 };
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
@@ -12,7 +12,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License. */
 #include "paddle/operators/maxout_op.h"
 namespace paddle {
 namespace operators {
@@ -33,18 +32,18 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
              "Where N is batch size, C is "
              "the number of channels, H and W is the height and "
              "width of feature.");
     AddAttr<int>(
         "groups",
         R"DOC(The group number of input layer.
         )DOC");
     AddComment(R"DOC(
       - Input: NCHW.
-      - Output: feature map size same as input. Channel is (input channel) / groups.
+      - Output: The feature map size of output is the same as the input.
+        The output_channel is (input channel) / groups
       So groups should be larger than 1, and the num of channels should be able
-      to devided by groups.
-      .. math::
+      to be devided by groups.
+      math:
       y_{si+j} = \max_k x_{gsi + sk + j}
       g = groups
       s = input.size / num_channels
@@ -57,29 +56,6 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
     - Multi-digit Number Recognition from Street View \
       Imagery using Deep Convolutional Neural Networks: \
       https://arxiv.org/pdf/1312.6082v4.pdf
-    The simple usage is:
-    .. code-block:: python
-      maxout = maxout_layer(input,
-                            num_channels=128,
-                            groups=4)
-    :param input: The input of this layer.
-    :type input: LayerOutput
-    :param num_channels: The channel number of input layer. If None will be set
-                         automatically from previous output.
-    :type num_channels: int | None
-    :param groups: The group number of input layer.
-    :type groups: int
-    :param name: The name of this layer. It is optional.
-    :type name: None | basestring.
-    :param layer_attr: Extra Layer attribute.
-    :type layer_attr: ExtraLayerAttribute
-    :return: LayerOutput object.
-    :rtype: LayerOutput
   )DOC");
   }
 };
@@ -88,7 +64,6 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
 class MaxOutOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of maxoutOp"
                    "should not be null.");
@@ -96,26 +71,20 @@ class MaxOutOp : public framework::OperatorWithKernel {
                    "Output(Out) of maxoutOp should not be null.");
     auto in_x_dims = ctx->GetInputDim("X");
     int groups = ctx->Attrs().Get<int>("groups");
     // check groups > 1
     PADDLE_ENFORCE_GT(
         groups, 1,
-        "in maxoutop groups should be larger than 1");
+        "groups should be larger than 1 in maxoutop");
     std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1] / groups});
     output_shape.push_back(in_x_dims[2]);
     output_shape.push_back(in_x_dims[3]);
     ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
   }
 };
 class MaxOutOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
@@ -129,8 +98,6 @@ class MaxOutOpGrad : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 REGISTER_OP(maxout, ops::MaxOutOp, ops::MaxOutOpMaker, maxout_grad,
             ops::MaxOutOpGrad);
 REGISTER_OP_CPU_KERNEL(maxout, ops::MaxOutKernel<paddle::platform::CPUPlace,
                                                  float>);
 REGISTER_OP_CPU_KERNEL(maxout_grad,
...
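Note on the maxout formula in the op's DOC comment above (y_{si+j} = \max_k x_{gsi + sk + j}): the sketch below is a minimal, standalone reference implementation of that computation for an NCHW float buffer. It is not code from this commit; the function name and signature are illustrative assumptions only. For example, with in_channels = 4 and groups = 2, output channel 0 is the elementwise max of input channels 0 and 1, and output channel 1 is the max of input channels 2 and 3.

// Standalone sketch (not part of the operator): maxout over an NCHW buffer.
// Output channel c at spatial position f is the max over input channels
// c * groups + k, k in [0, groups), matching the indexing used in the kernels above.
#include <cassert>
#include <cfloat>
#include <vector>

// Hypothetical helper for illustration only.
std::vector<float> MaxOutReference(const std::vector<float>& input,
                                   int batch, int in_channels,
                                   int height, int width, int groups) {
  assert(groups > 1 && in_channels % groups == 0);
  const int out_channels = in_channels / groups;
  const int feat = height * width;
  std::vector<float> output(static_cast<size_t>(batch) * out_channels * feat);
  for (int n = 0; n < batch; ++n) {
    for (int c = 0; c < out_channels; ++c) {
      for (int f = 0; f < feat; ++f) {
        float ele = -FLT_MAX;
        for (int k = 0; k < groups; ++k) {
          int in_c = c * groups + k;  // group member feeding output channel c
          float x = input[(n * in_channels + in_c) * feat + f];
          ele = ele > x ? ele : x;
        }
        output[(n * out_channels + c) * feat + f] = ele;
      }
    }
  }
  return output;
}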
@@ -29,16 +29,12 @@ class MaxOutKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& context) const override {
     const Tensor* in_x = context.Input<Tensor>("X");
     Tensor* out = context.Output<Tensor>("Out");
     int groups = context.template Attr<int>("groups");
     paddle::operators::math::MaxOutFunctor<
-        Place, paddle::operators::math::MaxOut<T>, T>
+        Place, T>
         maxout_forward;
-    paddle::operators::math::MaxOut<T> maxout_process;
-    maxout_forward(context.device_context(), *in_x, out, groups,
-                   maxout_process);
+    maxout_forward(context.device_context(), *in_x, out, groups);
   }
 };
@@ -51,15 +47,12 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
     const Tensor* out_grad =
         context.Input<Tensor>(framework::GradVarName("Out"));
     Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
     int groups = context.template Attr<int>("groups");
     auto& device_ctx = context.device_context();
     math::SetConstant<Place, T> zero;
     if (in_x_grad) {
       in_x_grad->mutable_data<T>(context.GetPlace());
       zero(device_ctx, in_x_grad, static_cast<T>(0.0));
       paddle::operators::math::MaxOutGradFunctor<Place, T>
           maxout_backward;
       maxout_backward(context.device_context(), *in_x, *in_x_grad, *out,
...