提交 5a9ea9a7 编写于 作者: S shippingwang

Add ShuffleChannel Op

上级 76c6f115
......@@ -28,7 +28,7 @@ class ShuffleChannelOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
// ENFORCE group
// auto group = ctx->Attrs().Get<int>("group");
ctx->SetOutputDim("Out", input_dims);
}
/*
......@@ -87,6 +87,10 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel {
"Output(X@Grad) should not be null");
auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
// ENFORCE group
ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
}
/*
......
......@@ -81,6 +81,7 @@ class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::Tensor>("X");
int group = ctx.Attr<int>("group");
auto input_dims = input->dims();
auto num = input_dims[0];
auto channel = input_dims[1];
......@@ -101,6 +102,7 @@ class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
int blocks = NumBlocks(output_grad->numel());
int threads = kNumCUDAThreads;
int count = num * group_column * group_row * sp_sz;
ShuffleChannel<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
count, feature_map_size, input_grad_data, output_grad_data, group_row,
......
......@@ -55,6 +55,8 @@ __all__ = [
'softmax',
'pool2d',
'pool3d',
'adaptive_pool2d',
'adaptive_pool3d',
'batch_norm',
'beam_search_decode',
'conv2d_transpose',
......@@ -9342,24 +9344,22 @@ def shuffle_channel(x, group=1, name=None):
x: The input tensor variable..
group: The num of group
Returns:
Variable: channel shuffled tensor variable.
Variable: channels shuffled tensor variable.
Raises:
ValueError: If group in not an int type variable.
ValueError: If group is not an int type variable.
Examples:
.. code-block:: python
out = fluid.layers.shuffle_channel(x=group_conv,group=4)
"""
helper = LayerHelper("shuffle_channel", **locals())
out = helper.create_variable_for_type_inference(
dtype=helper.intput_dtype('x'))
dtype=helper.input_dtype('x'))
if not isinstance(group, int):
raise TypeError("group must be int type")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册