Commit 16d4e137 authored by S shippingwang

Add ShuffleChannelOP

Parent 7f73c16e
@@ -19,26 +19,27 @@ class ShuffleChannelOp : public framework::OperatorWithKernel {
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of ShuffleChannelOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of ShuffleChannelOp should not be null.");

    auto input_dims = ctx->GetInputDim("X");
    PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
    // ENFORCE group
    // auto group = ctx->Attrs().Get<int>("group");

    ctx->SetOutputDim("Out", input_dims);
  }

  /*
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
        ctx.device_context());
  }
  */
};

class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -63,7 +64,7 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
    then, feed each group in the next layer with different subgroups.

    According to the paper, "Suppose a convolution layer with g groups
    whose output has g * n channels, first reshape the output channel dimension into (g, n),
    transposing and then flattening it back as the input of next layer."

    Shuffle channel operation makes it possible to build more powerful structures
@@ -75,52 +76,49 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
  }
};

class ShuffleChannelGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@Grad) should not be null");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                   "Output(X@Grad) should not be null");

    auto input_dims = ctx->GetInputDim("X");
    ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
  }

  /*
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
        ctx.device_context());
  }
  */
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(shuffle_channel, ops::ShuffleChannelOp,
                  ops::ShuffleChannelOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
// paddle::framework::EmptyGradOpMaker);
REGISTER_OPERATOR(shuffle_channel_grad, ops::ShuffleChannelGradOp);

REGISTER_OP_CPU_KERNEL(
    shuffle_channel,
    ops::ShuffleChannelOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ShuffleChannelOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    shuffle_channel_grad,
    ops::ShuffleChannelGradOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ShuffleChannelGradOpKernel<paddle::platform::CPUDeviceContext,
                                    double>);
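Note: the channel shuffle documented in the op comment above is just a reshape-transpose-reshape on the channel dimension. A small NumPy sketch of that reference computation follows (the helper name and the sample tensor are illustrative; it mirrors the ground truth used by the Python unit test further down):

import numpy as np

# Minimal NumPy reference for the channel shuffle described above.
def channel_shuffle(x, group):
    n, c, h, w = x.shape
    assert c % group == 0, "channel count must be divisible by group"
    x = np.reshape(x, (n, group, c // group, h, w))  # split C into (group, C // group)
    x = np.transpose(x, (0, 2, 1, 3, 4))             # swap the two group dimensions
    return np.reshape(x, (n, c, h, w))               # flatten back to NCHW

x = np.arange(1 * 4 * 2 * 2, dtype=np.float32).reshape(1, 4, 2, 2)
print(channel_shuffle(x, group=2)[0, :, 0, 0])  # channels come back in order 0, 2, 1, 3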
@@ -10,15 +10,115 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/shuffle_channel_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaximumNumBlocks);
}
template <typename T>
__global__ void ShuffleChannel(const int nthreads, const int feature_map_size,
T* output, const T* input, int group_row,
int group_column, int len) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
  // Grid-stride loop: each thread handles several flattened NCHW positions.
  for (int ii = index; ii < nthreads; ii += offset) {
    const int n = ii / group_row / group_column / len;
    const int i = (ii / group_column / len) % group_row;
    const int j = (ii / len) % group_column;
    const int k = ii - (n * feature_map_size + (i * group_column + j) * len);

    // Swap the (group_row, group_column) decomposition of the channel index.
    T* p_o = output + n * feature_map_size + (j * group_row + i) * len;
    p_o[k] = input[ii];
  }
}
template <typename DeviceContext, typename T>
class ShuffleChannelOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::Tensor>("X");
auto* output = ctx.Output<framework::Tensor>("Out");
int group = ctx.Attr<int>("group");
auto input_dims = input->dims();
auto num = input_dims[0];
auto channel = input_dims[1];
auto height = input_dims[2];
auto weight = input_dims[3];
auto feature_map_size = channel * height * weight;
auto sp_sz = height * weight;
int group_row = group;
int group_column = channel / group_row;
    // count is the total number of elements (N * C * H * W), i.e. numel().
int count = num * group_column * group_row * sp_sz;
int blocks = NumBlocks(output->numel());
int threads = kNumCUDAThreads;
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
ShuffleChannel<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
count, feature_map_size, output_data, input_data, group_row,
group_column, sp_sz);
}
};
template <typename DeviceContext, typename T>
class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::Tensor>("X");
int group = ctx.Attr<int>("group");
auto input_dims = input->dims();
auto num = input_dims[0];
auto channel = input_dims[1];
auto height = input_dims[2];
auto weight = input_dims[3];
auto feature_map_size = channel * height * weight;
auto sp_sz = height * weight;
int group_row = group;
int group_column = channel / group_row;
auto* output_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* input_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
const T* output_grad_data = output_grad->data<T>();
int blocks = NumBlocks(output_grad->numel());
int threads = kNumCUDAThreads;
int count = num * group_column * group_row * sp_sz;
ShuffleChannel<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
count, feature_map_size, input_grad_data, output_grad_data, group_row,
group_column, sp_sz);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    shuffle_channel,
    ops::ShuffleChannelOpCUDAKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ShuffleChannelOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                    double>);
REGISTER_OP_CUDA_KERNEL(
    shuffle_channel_grad,
    ops::ShuffleChannelGradOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        float>,
    ops::ShuffleChannelGradOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        double>);
@@ -21,10 +21,10 @@ namespace operators {
template <typename DeviceContext, typename T>
class ShuffleChannelOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<framework::Tensor>("X");
    auto* output = ctx.Output<framework::Tensor>("Out");
    int group = ctx.Attr<int>("group");

    auto input_dims = input->dims();
    auto num = input_dims[0];
@@ -34,21 +34,19 @@ class ShuffleChannelOpKernel : public framework::OpKernel<T> {
    auto feature_map_size = channel * height * weight;
    auto sp_sz = height * weight;

    int group_row = group;
    int group_column = channel / group_row;

    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>(ctx.GetPlace());
    for (int n = 0; n < num; ++n) {
      for (int i = 0; i < group_row; ++i) {
        for (int j = 0; j < group_column; ++j) {
          const T* p_i = input_data + n * feature_map_size +
                         (i * group_column + j) * sp_sz;
          T* p_o =
              output_data + n * feature_map_size + (j * group_row + i) * sp_sz;
          memcpy(p_o, p_i, sizeof(T) * sp_sz);
        }
      }
    }
@@ -61,7 +59,7 @@ class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<framework::Tensor>("X");
    int group = ctx.Attr<int>("group");

    auto input_dims = input->dims();
    auto num = input_dims[0];
@@ -72,7 +70,7 @@ class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
    auto sp_sz = height * weight;

    int group_row = group;
    int group_column = channel / group_row;

    auto* output_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
@@ -81,19 +79,17 @@ class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
    T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
    const T* output_grad_data = output_grad->data<T>();
    for (int n = 0; n < num; ++n) {
      for (int i = 0; i < group_row; ++i) {
        for (int j = 0; j < group_column; ++j) {
          const T* p_i = output_grad_data + n * feature_map_size +
                         (i * group_column + j) * sp_sz;
          T* p_o = input_grad_data + n * feature_map_size +
                   (j * group_row + i) * sp_sz;
          memcpy(p_o, p_i, sizeof(T) * sp_sz);
        }
      }
    }
  }
};
......
@@ -173,7 +173,7 @@ __all__ = [
    'merge_selected_rows',
    'get_tensor_from_selected_rows',
    'lstm',
    'shuffle_channel',
    'psroi_pool',
]
@@ -9334,17 +9334,20 @@ def shuffle_channel(x, group=1, name=None):
    with multiple group convolutional layers.

    Args:
        x: The input tensor variable.
        group: The number of groups.

    Returns:
        Variable: channel shuffled tensor variable.

    Raises:
        ValueError: If group is not an int type variable.

    Examples:
        .. code-block:: python

            out = fluid.layers.shuffle_channel(x=group_conv, group=4)
    """
@@ -9361,6 +9364,7 @@ def shuffle_channel(x, group=1, name=None):
        inputs={"X": x},
        outputs={"Out": out},
        attrs={"group": group})
    return out


@templatedoc()
......
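For a quick sanity check of the new Python layer, a minimal sketch of running it through the stock fluid Executor could look as follows (the variable names, shapes, and feed data are illustrative, and it assumes the operator registered above is built into the framework):

import numpy as np
import paddle.fluid as fluid

# Hypothetical end-to-end check of the shuffle_channel layer.
data = fluid.layers.data(name='data', shape=[4, 2, 2], dtype='float32')
shuffled = fluid.layers.shuffle_channel(x=data, group=2)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

feed_x = np.random.random((1, 4, 2, 2)).astype('float32')
out, = exe.run(fluid.default_main_program(),
               feed={'data': feed_x},
               fetch_list=[shuffled])
print(out.shape)  # (1, 4, 2, 2): shape is preserved, only channels are permuted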
@@ -1018,7 +1018,7 @@ class TestBook(unittest.TestCase):
    def test_shuffle_channel(self):
        program = Program()
        with program_guard(program):
            x = layers.data(name="x", shape=[1, 4, 2, 2], dtype="float32")
            out = layers.shuffle_channel(x, group=2)
        self.assertIsNotNone(out)
......
@@ -23,31 +23,29 @@ import paddle.fluid.core as core
class TestShuffleChannelOp(OpTest):
    def setUp(self):
        self.op_type = "shuffle_channel"
        self.batch_size = 1
        self.input_channels = 4
        self.layer_h = 2
        self.layer_w = 2
        self.group = 2

        self.x = np.random.random(
            (self.batch_size, self.input_channels, self.layer_h,
             self.layer_w)).astype('float32')
        self.inputs = {'X': self.x}
        self.attrs = {'group': self.group}

        n, c, h, w = self.x.shape
        input_reshaped = np.reshape(self.x,
                                    (-1, self.group, c // self.group, h, w))
        input_transposed = np.transpose(input_reshaped, (0, 2, 1, 3, 4))
        self.outputs = {'Out': np.reshape(input_transposed, (-1, c, h, w))}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


if __name__ == '__main__':
......