Commit 5a9ea9a7, authored by: shippingwang

Add ShuffleChannel Op

Parent: 76c6f115
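The new op implements ShuffleNet-style channel shuffle on an NCHW tensor: the C channels are viewed as a group × (C / group) grid and that grid is transposed, interleaving channels across groups. A minimal NumPy sketch of the intended semantics (the helper name is illustrative, not Paddle's API):

import numpy as np

def shuffle_channel_ref(x, group):
    # View the C axis as (group, C // group), transpose the two
    # group axes, and flatten back to NCHW.
    n, c, h, w = x.shape
    assert c % group == 0, "channel count must be divisible by group"
    return (x.reshape(n, group, c // group, h, w)
             .transpose(0, 2, 1, 3, 4)
             .reshape(n, c, h, w))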
@@ -28,7 +28,7 @@ class ShuffleChannelOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
     // ENFORCE group
+    // auto group = ctx->Attrs().Get<int>("group");
     ctx->SetOutputDim("Out", input_dims);
   }
   /*
@@ -87,6 +87,10 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel {
                    "Output(X@Grad) should not be null");
     auto input_dims = ctx->GetInputDim("X");
+    PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
+    // ENFORCE group
     ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
   }
   /*
...
@@ -81,6 +81,7 @@ class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* input = ctx.Input<framework::Tensor>("X");
     int group = ctx.Attr<int>("group");
     auto input_dims = input->dims();
     auto num = input_dims[0];
     auto channel = input_dims[1];
@@ -101,6 +102,7 @@ class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
     int blocks = NumBlocks(output_grad->numel());
     int threads = kNumCUDAThreads;
+    int count = num * group_column * group_row * sp_sz;
     ShuffleChannel<
         T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
         count, feature_map_size, input_grad_data, output_grad_data, group_row,
...
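A note on the backward pass above: channel shuffle is a pure permutation, so its gradient is the inverse permutation, and transposing a group_row × group_column grid is undone by transposing the result as a group_column × group_row grid. That is one way to see why a single ShuffleChannel kernel can serve both the forward and the grad computation (the exact argument order in the kernel launch is not fully visible in this hunk). A hedged NumPy check of the identity, reusing the illustrative shuffle_channel_ref sketched earlier:

x = np.random.rand(2, 8, 3, 3).astype("float32")
y = shuffle_channel_ref(x, group=4)
# Shuffling again with the complementary factor (C // group) restores the input.
assert np.allclose(shuffle_channel_ref(y, group=2), x)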
@@ -55,6 +55,8 @@ __all__ = [
     'softmax',
     'pool2d',
     'pool3d',
+    'adaptive_pool2d',
+    'adaptive_pool3d',
     'batch_norm',
     'beam_search_decode',
     'conv2d_transpose',
@@ -9342,24 +9344,22 @@ def shuffle_channel(x, group=1, name=None):
         x: The input tensor variable.
         group: The number of groups.
     Returns:
-        Variable: channel shuffled tensor variable.
+        Variable: channels shuffled tensor variable.
     Raises:
-        ValueError: If group in not an int type variable.
+        ValueError: If group is not an int type variable.
     Examples:
         .. code-block:: python

             out = fluid.layers.shuffle_channel(x=group_conv, group=4)
     """
     helper = LayerHelper("shuffle_channel", **locals())
     out = helper.create_variable_for_type_inference(
-        dtype=helper.intput_dtype('x'))
+        dtype=helper.input_dtype('x'))
     if not isinstance(group, int):
         raise TypeError("group must be int type")
...
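For completeness, a hedged end-to-end usage sketch of the new Python layer, expanding the docstring's one-liner (written against the fluid 1.x API; the shapes and layer parameters here are illustrative):

import paddle.fluid as fluid

data = fluid.layers.data(name='x', shape=[8, 16, 16], dtype='float32')
# A grouped convolution leaves channels segregated by group;
# shuffle_channel mixes them before the next grouped layer.
group_conv = fluid.layers.conv2d(
    input=data, num_filters=8, filter_size=3, groups=4, padding=1)
out = fluid.layers.shuffle_channel(x=group_conv, group=4)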