Commit c2575ac7 authored by S shippingwang

update API.spec, test=develop

Parent ebeee930
This diff is collapsed.
@@ -27,8 +27,6 @@ class ShuffleChannelOp : public framework::OperatorWithKernel {
     auto input_dims = ctx->GetInputDim("X");
     PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
-    // ENFORCE group
     ctx->SetOutputDim("Out", input_dims);
   }
   /*
@@ -60,11 +58,11 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
     Shuffle Channel operator
     This operator obtains the group convolutional layer with channels shuffled.
-    First, divide the input channels in each group into several subgroups,
+    Firstly, divide the input channels in each group into several subgroups,
     then, feed each group in the next layer with different subgroups.
-    According to the paper, "Suppose a convolution layer with g groups
-    whose output has g * n channels, first reshape the output channel dimension into(g,n),
+    According to the paper, "Suppose a convolution layer with G groups
+    whose output has (G * N) channels, first reshape the output channel dimension into(G,N),
     transposing and then flattening it back as the input of next layer. "
     Shuffle channel operation makes it possible to build more powerful structures
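The DOC comment above describes the shuffle as: reshape the channel dimension to (G, N), transpose, and flatten back. A minimal CPU sketch of that mapping, assuming an NCHW float tensor whose channel count is divisible by group; the function and parameter names here are illustrative, not the operator's actual kernel:

#include <cstring>

// Hypothetical sketch: channel (g, c) of the input, with g in [0, G) and
// c in [0, N), moves to channel (c, g) of the output -- i.e. reshape the
// channel axis to (G, N), transpose to (N, G), flatten back.
void ShuffleChannelCPU(const float* input, float* output, int num,
                       int channels, int height, int width, int group) {
  const int sp_sz = height * width;            // elements in one channel
  const int feature_map_size = channels * sp_sz;
  const int group_row = group;                 // G in the paper
  const int group_column = channels / group;   // N in the paper
  for (int n = 0; n < num; ++n) {
    for (int g = 0; g < group_row; ++g) {
      for (int c = 0; c < group_column; ++c) {
        const float* src =
            input + n * feature_map_size + (g * group_column + c) * sp_sz;
        float* dst =
            output + n * feature_map_size + (c * group_row + g) * sp_sz;
        std::memcpy(dst, src, sp_sz * sizeof(float));
      }
    }
  }
}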
@@ -89,8 +87,6 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel {
     auto input_dims = ctx->GetInputDim("X");
     PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
-    // ENFORCE group
     ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
   }
   /*
@@ -112,7 +108,6 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(shuffle_channel, ops::ShuffleChannelOp,
                   ops::ShuffleChannelOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
-// paddle::framework::EmptyGradOpMaker);
 REGISTER_OPERATOR(shuffle_channel_grad, ops::ShuffleChannelGradOp);
......
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
@@ -26,7 +26,6 @@ static inline int NumBlocks(const int N) {
 }
 template <typename T>
 __global__ void ShuffleChannel(const int nthreads, const int feature_map_size,
                                T* output, const T* input, int group_row,
                                int group_column, int len) {
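Given the signature above, a plausible kernel body is a grid-stride loop in which each thread copies one element to its transposed channel position. This is a hedged sketch under two assumptions -- that len is the per-channel spatial size (H * W) and that nthreads covers every element -- not necessarily the upstream implementation:

template <typename T>
__global__ void ShuffleChannelSketch(const int nthreads,
                                     const int feature_map_size, T* output,
                                     const T* input, int group_row,
                                     int group_column, int len) {
  // Grid-stride loop: one thread per input element.
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    const int n = index / feature_map_size;                   // batch item
    const int i = (index / (group_column * len)) % group_row; // source row
    const int j = (index / len) % group_column;               // source column
    const int k = index % len;                                // offset in channel
    // Element (i, j) of the (group_row, group_column) channel grid lands at
    // the transposed position (j, i) in the output.
    T* p_o = output + n * feature_map_size + (j * group_row + i) * len;
    p_o[k] = input[index];
  }
}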
......
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
@@ -76,7 +76,6 @@ class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
         ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
     auto* input_grad =
         ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
     const T* output_grad_data = output_grad->data<T>();
     for (int n = 0; n < num; ++n) {
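Because the shuffle is a pure permutation of channels, its gradient is the inverse permutation, and swapping the roles of G and N inverts the forward mapping. In terms of the hypothetical ShuffleChannelCPU sketch shown earlier (and assuming num, channels, height, width, and group are in scope as in this grad kernel), the backward pass could be expressed as:

// Assumption: shuffling dOut with group swapped to channels / group inverts
// the forward shuffle, yielding dX.
ShuffleChannelCPU(output_grad_data, input_grad_data, num, channels, height,
                  width, /*group=*/channels / group);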
......