Commit 76fc1a82 authored by wanghaox

for code review 4

Parent 52f2366d
@@ -18,10 +18,7 @@ namespace paddle {
 namespace operators {
 namespace math {
-/*
- * All tensors are in NCHW format.
- * groups mustbe > 1
- */
+// All tensors are in NCHW format, and the groups must be greater than 1
 template <typename T>
 class MaxOutFunctor<platform::CPUPlace, T> {
  public:
@@ -44,7 +41,6 @@ class MaxOutFunctor<platform::CPUPlace, T> {
     for (int c = 0; c < output_channels; ++c) {
       int new_cindex = fea_size * c;
       for (int f = 0; f < fea_size; ++f) {
-        // T ele = maxout_process.initial();
         T ele = static_cast<T>(-FLT_MAX);
         for (int ph = 0; ph < groups; ++ph) {
           T x = input_data[(new_bindex + new_cindex) * groups
@@ -65,7 +61,7 @@ class MaxOutGradFunctor<platform::CPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input,
-                  framework::Tensor& input_grad,
+                  framework::Tensor * input_grad,
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad,
                   int groups) {
@@ -77,7 +73,7 @@ public:
     const T* input_data = input.data<T>();
     const T* output_data = output.data<T>();
     const T* output_grad_data = output_grad.data<T>();
-    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
+    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
     for (int i = 0; i < batch_size; ++i) {
       int blen = fea_size * output_channels * i;
......
@@ -112,7 +112,8 @@ template <typename T>
 class MaxOutGradFunctor<platform::GPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
-                  const framework::Tensor& input, framework::Tensor& input_grad,
+                  const framework::Tensor& input,
+                  framework::Tensor * input_grad,
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad,
                   int groups) {
@@ -127,7 +128,7 @@ class MaxOutGradFunctor<platform::GPUPlace, T> {
     const T* input_data = input.data<T>();
     const T* output_data = output.data<T>();
     const T* output_grad_data = output_grad.data<T>();
-    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
+    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
     int nthreads = output.numel();
     int blocks = (nthreads + 1024 - 1) / 1024;
     dim3 threads(1024, 1);
......
@@ -38,7 +38,7 @@ class MaxOutGradFunctor {
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input,
-                  framework::Tensor& input_grad,
+                  framework::Tensor * input_grad,
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad, int groups);
 };
......
@@ -34,14 +34,13 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
              "width of feature.");
     AddAttr<int>(
         "groups",
-        R"DOC(The group number of input layer.
+        R"DOC(Specifies how many groups the input tensor will be split into
+              along the channel dimension. The number of output channels is
+              the number of input channels divided by groups.
         )DOC");
     AddComment(R"DOC(
-        - Input: NCHW.
-        - Output: The feature map size of output is the same as the input.
-          The output_channel is (input channel) / groups
-          So groups should be larger than 1, and the num of channels should be able
-          to be devided by groups.
+        Assume the input shape is (N, Ci, H, W). The output shape is
+        (N, Co, H, W), where Co = Ci / groups.
       math:
         y_{si+j} = \max_k x_{gsi + sk + j}
@@ -65,10 +64,10 @@ class MaxOutOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of maxoutOp"
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MaxoutOp "
                    "should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of maxoutOp should not be null.");
+                   "Output(Out) of MaxoutOp should not be null.");
     auto in_x_dims = ctx->GetInputDim("X");
     int groups = ctx->Attrs().Get<int>("groups");
     // check groups > 1
......
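The attribute description and the formula y_{si+j} = \max_k x_{gsi + sk + j} above amount to taking a max over each group of `groups` consecutive channels, so the channel count shrinks by a factor of `groups`. A minimal NumPy sketch of that forward computation (the function name and sample shapes below are illustrative only and are not part of this patch; the test file's maxout_forward_naive is the authoritative reference):

import numpy as np

def maxout_naive(x, groups):
    # x: (N, C, H, W); C must be divisible by groups (illustrative helper).
    n, c, h, w = x.shape
    assert c % groups == 0
    # Group consecutive channels and take the max inside each group,
    # producing an output of shape (N, C // groups, H, W).
    return x.reshape(n, c // groups, groups, h, w).max(axis=2)

x = np.arange(2 * 6 * 1 * 1, dtype=np.float32).reshape(2, 6, 1, 1)
print(maxout_naive(x, groups=2).shape)  # (2, 3, 1, 1)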
@@ -12,7 +12,6 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
-#define EIGEN_USE_GPU
 #include "paddle/operators/maxout_op.h"
 namespace ops = paddle::operators;
......
@@ -31,9 +31,7 @@ class MaxOutKernel : public framework::OpKernel<T> {
     Tensor* out = context.Output<Tensor>("Out");
     int groups = context.template Attr<int>("groups");
-    paddle::operators::math::MaxOutFunctor<
-        Place, T>
-        maxout_forward;
+    math::MaxOutFunctor<Place, T> maxout_forward;
     maxout_forward(context.device_context(), *in_x, out, groups);
   }
 };
@@ -53,10 +51,9 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
     if (in_x_grad) {
       in_x_grad->mutable_data<T>(context.GetPlace());
       zero(device_ctx, in_x_grad, static_cast<T>(0.0));
-      paddle::operators::math::MaxOutGradFunctor<Place, T>
-          maxout_backward;
-      maxout_backward(context.device_context(), *in_x, *in_x_grad, *out,
-                      *out_grad, groups);
+      math::MaxOutGradFunctor<Place, T> maxout_backward;
+      maxout_backward(context.device_context(), *in_x, in_x_grad, *out,
+                      *out_grad, groups);
     }
   }
 };
......
@@ -3,7 +3,7 @@ import numpy as np
 from op_test import OpTest
-def maxout_forward_naive(input, groups,num_channels):
+def maxout_forward_naive(input, groups):
     s0, s1, s2, s3 = input.shape
     return np.ndarray([s0, s1 / groups, groups, s2, s3], \
         buffer = input, dtype=input.dtype).max(axis=(2))
@@ -18,7 +18,7 @@ class TestMaxOutOp(OpTest):
             self.num_channels).astype("float32")
         self.inputs = {'X': input}
-        self.attrs = {'groups': self.groups, 'num_channels': self.num_channels}
+        self.attrs = {'groups': self.groups}
         self.outputs = {'Out': output.astype('float32')}
@@ -32,7 +32,6 @@ class TestMaxOutOp(OpTest):
         self.MaxOut_forward_naive = maxout_forward_naive
         self.shape = [100, 6, 2, 2]
         self.groups=2
-        self.num_channels=6
......