Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
91d2a57a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
91d2a57a
编写于
7月 21, 2017
作者:
Z
Zhaolong Xing
提交者:
GitHub
7月 21, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2776 from NHZlX/mobilenet_gpu
Mobilenet gpu implementation
上级
e2880f16
6c528cbc
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
993 addition
and
22 deletion
+993
-22
paddle/function/ConvOpTest.cpp
paddle/function/ConvOpTest.cpp
+115
-19
paddle/function/DepthwiseConvOp.cpp
paddle/function/DepthwiseConvOp.cpp
+306
-0
paddle/function/DepthwiseConvOp.h
paddle/function/DepthwiseConvOp.h
+159
-0
paddle/function/DepthwiseConvOpGpu.cu
paddle/function/DepthwiseConvOpGpu.cu
+342
-0
paddle/gserver/layers/ExpandConvLayer.cpp
paddle/gserver/layers/ExpandConvLayer.cpp
+18
-3
paddle/gserver/tests/test_LayerGrad.cpp
paddle/gserver/tests/test_LayerGrad.cpp
+49
-0
python/paddle/trainer/config_parser.py
python/paddle/trainer/config_parser.py
+4
-0
未找到文件。
paddle/function/ConvOpTest.cpp
浏览文件 @
91d2a57a
...
...
@@ -31,13 +31,22 @@ public:
ConvolutionTest
(
const
std
::
string
&
conv1
,
const
std
::
string
&
conv2
,
TestType
type
,
bool
useGroups
=
true
,
std
::
string
algo
=
"auto"
)
{
for
(
size_t
batchSize
:
{
1
,
32
})
{
for
(
size_t
inputSize
:
{
7
,
14
,
54
})
{
for
(
size_t
filterSize
:
{
1
,
3
,
5
})
{
for
(
size_t
inputChannels
:
{
3
,
64
})
{
for
(
size_t
outputChannels
:
{
3
,
64
,
128
})
{
if
(
inputChannels
<
outputChannels
)
break
;
for
(
size_t
outputChannels
:
{
3
,
64
})
{
if
(
inputChannels
>
outputChannels
)
break
;
size_t
groups
;
if
(
!
useGroups
)
{
groups
=
1
;
}
else
{
if
(
outputChannels
%
inputChannels
!=
0
)
continue
;
groups
=
inputChannels
;
}
for
(
size_t
stride
:
{
1
,
2
})
{
for
(
size_t
padding
:
{
0
,
1
})
{
if
(
padding
>=
filterSize
)
break
;
...
...
@@ -62,13 +71,24 @@ public:
FuncConfig
()
.
set
(
"paddings"
,
paddings
)
.
set
(
"strides"
,
strides
)
.
set
(
"groups"
,
(
size_t
)
1
)
.
set
(
"groups"
,
groups
)
.
set
(
"algo"
,
algo
));
TensorShape
input
{
batchSize
,
inputChannels
,
inputSize
,
inputSize
};
TensorShape
filter
{
outputChannels
,
inputChannels
,
filterSize
,
filterSize
};
TensorShape
filter
;
if
(
groups
>
1
)
filter
=
TensorShape
({
groups
,
outputChannels
/
groups
,
inputChannels
/
groups
,
filterSize
,
filterSize
});
else
filter
=
TensorShape
({
outputChannels
,
inputChannels
,
filterSize
,
filterSize
});
TensorShape
output
{
batchSize
,
outputChannels
,
outputSize
,
outputSize
};
...
...
@@ -85,7 +105,8 @@ public:
}
else
if
(
type
==
kBackwardFilterTest
)
{
test
.
addInputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
output
));
test
.
addInputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
input
));
test
.
addOutputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
filter
));
test
.
addOutputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
filter
),
ADD_TO
);
test
.
run
();
}
}
...
...
@@ -106,6 +127,7 @@ public:
ConvolutionTest2
(
const
std
::
string
&
conv1
,
const
std
::
string
&
conv2
,
TestType
type
,
bool
useGroups
=
true
,
std
::
string
algo
=
"auto"
)
{
for
(
size_t
batchSize
:
{
16
})
{
for
(
size_t
inputHeight
:
{
7
,
31
})
{
...
...
@@ -113,7 +135,15 @@ public:
for
(
size_t
filterHeight
:
{
1
,
5
})
{
for
(
size_t
filterWidth
:
{
3
,
7
})
{
for
(
size_t
inputChannels
:
{
7
})
{
for
(
size_t
outputChannels
:
{
32
})
{
for
(
size_t
outputChannels
:
{
7
})
{
size_t
groups
;
if
(
!
useGroups
)
{
groups
=
1
;
}
else
{
if
(
outputChannels
%
inputChannels
!=
0
)
continue
;
groups
=
inputChannels
;
}
size_t
stride
=
1
;
size_t
padding
=
0
;
size_t
outputHeight
=
...
...
@@ -141,13 +171,24 @@ public:
FuncConfig
()
.
set
(
"paddings"
,
paddings
)
.
set
(
"strides"
,
strides
)
.
set
(
"groups"
,
(
size_t
)
1
)
.
set
(
"groups"
,
groups
)
.
set
(
"algo"
,
algo
));
TensorShape
input
{
batchSize
,
inputChannels
,
inputHeight
,
inputWidth
};
TensorShape
filter
{
outputChannels
,
inputChannels
,
filterHeight
,
filterWidth
};
TensorShape
filter
;
if
(
groups
>
1
)
filter
=
TensorShape
({
groups
,
outputChannels
/
groups
,
inputChannels
/
groups
,
filterHeight
,
filterWidth
});
else
filter
=
TensorShape
({
outputChannels
,
inputChannels
,
filterHeight
,
filterWidth
});
TensorShape
output
{
batchSize
,
outputChannels
,
outputHeight
,
outputWidth
};
...
...
@@ -164,7 +205,8 @@ public:
}
else
if
(
type
==
kBackwardFilterTest
)
{
test
.
addInputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
output
));
test
.
addInputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
input
));
test
.
addOutputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
filter
));
test
.
addOutputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
filter
),
ADD_TO
);
test
.
run
();
}
}
...
...
@@ -177,34 +219,88 @@ public:
}
};
// ======Start Convolution TEST======
TEST
(
Forward
,
GEMM
)
{
ConvolutionTest
<
DEVICE_TYPE_CPU
,
DEVICE_TYPE_CPU
>
test
(
"NaiveConv-CPU"
,
"GemmConv-CPU"
,
kForwardTest
);
"NaiveConv-CPU"
,
"GemmConv-CPU"
,
kForwardTest
,
false
);
ConvolutionTest2
<
DEVICE_TYPE_CPU
,
DEVICE_TYPE_CPU
>
test2
(
"NaiveConv-CPU"
,
"GemmConv-CPU"
,
kForwardTest
);
"NaiveConv-CPU"
,
"GemmConv-CPU"
,
kForwardTest
,
false
);
}
#ifndef PADDLE_ONLY_CPU
TEST
(
Forward
,
GEMM2
)
{
ConvolutionTest
<
DEVICE_TYPE_CPU
,
DEVICE_TYPE_GPU
>
test
(
"GemmConv-CPU"
,
"GemmConv-GPU"
,
kForwardTest
);
"GemmConv-CPU"
,
"GemmConv-GPU"
,
kForwardTest
,
false
);
ConvolutionTest2
<
DEVICE_TYPE_CPU
,
DEVICE_TYPE_GPU
>
test2
(
"GemmConv-CPU"
,
"GemmConv-GPU"
,
kForwardTest
);
"GemmConv-CPU"
,
"GemmConv-GPU"
,
kForwardTest
,
false
);
}
TEST
(
BackwardInput
,
GEMM
)
{
ConvolutionTest
<
DEVICE_TYPE_CPU
,
DEVICE_TYPE_GPU
>
test
(
"GemmConvGradInput-CPU"
,
"GemmConvGradInput-GPU"
,
kBackwardInputTest
);
"GemmConvGradInput-CPU"
,
"GemmConvGradInput-GPU"
,
kBackwardInputTest
,
false
);
ConvolutionTest2
<
DEVICE_TYPE_CPU
,
DEVICE_TYPE_GPU
>
test2
(
"GemmConvGradInput-CPU"
,
"GemmConvGradInput-GPU"
,
kBackwardInputTest
);
"GemmConvGradInput-CPU"
,
"GemmConvGradInput-GPU"
,
kBackwardInputTest
,
false
);
}
TEST
(
BackwardFilter
,
GEMM
)
{
ConvolutionTest
<
DEVICE_TYPE_CPU
,
DEVICE_TYPE_GPU
>
test
(
"GemmConvGradFilter-CPU"
,
"GemmConvGradFilter-GPU"
,
kBackwardFilterTest
);
"GemmConvGradFilter-CPU"
,
"GemmConvGradFilter-GPU"
,
kBackwardFilterTest
,
false
);
ConvolutionTest2
<
DEVICE_TYPE_CPU
,
DEVICE_TYPE_GPU
>
test2
(
"GemmConvGradFilter-CPU"
,
"GemmConvGradFilter-GPU"
,
kBackwardFilterTest
);
"GemmConvGradFilter-CPU"
,
"GemmConvGradFilter-GPU"
,
kBackwardFilterTest
,
false
);
}
#endif
// ======End Convolution TEST======
// ======Start DepthwiseConvolution TEST======
// TODO(zhaolong) The depthwise convolution cpu test will be added when the cpu
// version of depthwiseConv is implemented.
#ifndef PADDLE_ONLY_CPU
TEST
(
DepthwiseConvForward
,
GEMM2
)
{
ConvolutionTest
<
DEVICE_TYPE_CPU
,
DEVICE_TYPE_GPU
>
test
(
"GemmConv-CPU"
,
"DepthwiseConv-GPU"
,
kForwardTest
);
ConvolutionTest2
<
DEVICE_TYPE_CPU
,
DEVICE_TYPE_GPU
>
test2
(
"GemmConv-CPU"
,
"DepthwiseConv-GPU"
,
kForwardTest
);
}
TEST
(
DepthwiseConvBackwardInput
,
GEMM
)
{
ConvolutionTest
<
DEVICE_TYPE_CPU
,
DEVICE_TYPE_GPU
>
test
(
"GemmConvGradInput-CPU"
,
"DepthwiseConvGradInput-GPU"
,
kBackwardInputTest
);
ConvolutionTest2
<
DEVICE_TYPE_CPU
,
DEVICE_TYPE_GPU
>
test2
(
"GemmConvGradInput-CPU"
,
"DepthwiseConvGradInput-GPU"
,
kBackwardInputTest
);
}
TEST
(
DepthwiseConvBackwardFilter
,
GEMM
)
{
ConvolutionTest
<
DEVICE_TYPE_CPU
,
DEVICE_TYPE_GPU
>
test
(
"GemmConvGradFilter-CPU"
,
"DepthwiseConvGradFilter-GPU"
,
kBackwardFilterTest
);
ConvolutionTest2
<
DEVICE_TYPE_CPU
,
DEVICE_TYPE_GPU
>
test2
(
"GemmConvGradFilter-CPU"
,
"DepthwiseConvGradFilter-GPU"
,
kBackwardFilterTest
);
}
#endif
// ======End DepthwiseConvolution TEST======
}
// namespace paddle
paddle/function/DepthwiseConvOp.cpp
0 → 100644
浏览文件 @
91d2a57a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "DepthwiseConvOp.h"
#include "ConvOp.h"
#include "GemmFunctor.h"
namespace
paddle
{
template
<
class
T
>
class
DepthwiseConvFunctor
<
DEVICE_TYPE_CPU
,
T
>
{
public:
void
operator
()(
const
T
*
inputData
,
const
T
*
filterData
,
int
batchSize
,
int
outputChannels
,
int
outputHeight
,
int
outputWidth
,
int
inputChannels
,
int
inputHeight
,
int
inputWidth
,
int
filterMultiplier
,
int
filterHeight
,
int
filterWidth
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
T
*
outputData
)
{
// TODO(zhaolong) : cpu implementation of depthwise convolution
}
};
template
<
class
T
>
class
DepthwiseConvGradInputFunctor
<
DEVICE_TYPE_CPU
,
T
>
{
public:
void
operator
()(
const
T
*
outputGrad
,
const
T
*
filterData
,
int
batchSize
,
int
outputChannels
,
int
outputHeight
,
int
outputWidth
,
int
inputChannels
,
int
inputHeight
,
int
inputWidth
,
int
filterMultiplier
,
int
filterHeight
,
int
filterWidth
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
T
*
inputGrad
)
{}
// TODO(zhaolong) : cpu implementation of depthwise convolution
};
template
<
class
T
>
class
DepthwiseConvGradFilterFunctor
<
DEVICE_TYPE_CPU
,
T
>
{
public:
void
operator
()(
const
T
*
outputGrad
,
const
T
*
inputData
,
int
batchSize
,
int
outputChannels
,
int
outputHeight
,
int
outputWidth
,
int
inputChannels
,
int
inputHeight
,
int
inputWidth
,
int
filterMultiplier
,
int
filterHeight
,
int
filterWidth
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
T
*
colData
,
T
*
filterGrad
)
{}
// TODO(zhaolong) : cpu implementation of depthwise convolution
};
/*
* \brief Forward calculation of depthwise convolution.
*/
template
<
DeviceType
Device
>
class
DepthwiseConvFunction
:
public
ConvFunctionBase
{
public:
void
init
(
const
FuncConfig
&
config
)
override
{
ConvFunctionBase
::
init
(
config
);
}
void
check
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
const
TensorShape
&
input
=
inputs
[
0
].
shape
();
const
TensorShape
&
filter
=
inputs
[
1
].
shape
();
const
TensorShape
&
output
=
outputs
[
0
].
shape
();
checkShape
(
input
,
filter
,
output
);
}
void
calc
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
CHECK_EQ
(
numInputs_
,
inputs
.
size
());
CHECK_EQ
(
numOutputs_
,
outputs
.
size
());
check
(
inputs
,
outputs
);
const
TensorShape
&
input
=
inputs
[
0
].
shape
();
const
TensorShape
&
filter
=
inputs
[
1
].
shape
();
const
TensorShape
&
output
=
outputs
[
0
].
shape
();
size_t
batchSize
=
input
[
0
];
size_t
inputChannels
=
input
[
1
];
size_t
inputHeight
=
input
[
2
];
size_t
inputWidth
=
input
[
3
];
size_t
filterHeight
=
getFilterHeight
(
filter
);
size_t
filterWidth
=
getFilterWidth
(
filter
);
size_t
outputChannels
=
output
[
1
];
size_t
outputHeight
=
output
[
2
];
size_t
outputWidth
=
output
[
3
];
size_t
filterMultiplier
=
outputChannels
/
groups_
;
CHECK_EQ
(
inputChannels
,
groups_
);
real
*
inputData
=
inputs
[
0
].
data
<
real
>
();
real
*
filterData
=
inputs
[
1
].
data
<
real
>
();
real
*
outputData
=
outputs
[
0
].
data
<
real
>
();
DepthwiseConvFunctor
<
Device
,
real
>
depthwiseConv
;
depthwiseConv
(
inputData
,
filterData
,
batchSize
,
outputChannels
,
outputHeight
,
outputWidth
,
inputChannels
,
inputHeight
,
inputWidth
,
filterMultiplier
,
filterHeight
,
filterWidth
,
strideH
(),
strideW
(),
paddingH
(),
paddingW
(),
outputData
);
}
};
/*
* \brief Backward input calculation of depthwise convolution.
*/
template
<
DeviceType
Device
>
class
DepthwiseConvGradInputFunction
:
public
ConvFunctionBase
{
public:
void
init
(
const
FuncConfig
&
config
)
override
{
ConvFunctionBase
::
init
(
config
);
}
void
check
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
const
TensorShape
&
output
=
inputs
[
0
].
shape
();
const
TensorShape
&
filter
=
inputs
[
1
].
shape
();
const
TensorShape
&
input
=
outputs
[
0
].
shape
();
checkShape
(
input
,
filter
,
output
);
}
void
calc
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
CHECK_EQ
(
numInputs_
,
inputs
.
size
());
CHECK_EQ
(
numOutputs_
,
outputs
.
size
());
CHECK_EQ
(
outputs
[
0
].
getArgType
(),
ADD_TO
);
check
(
inputs
,
outputs
);
CHECK_EQ
(
outputs
[
0
].
getArgType
(),
ADD_TO
);
const
TensorShape
&
output
=
inputs
[
0
].
shape
();
const
TensorShape
&
filter
=
inputs
[
1
].
shape
();
const
TensorShape
&
input
=
outputs
[
0
].
shape
();
size_t
batchSize
=
input
[
0
];
size_t
inputChannels
=
input
[
1
];
size_t
inputHeight
=
input
[
2
];
size_t
inputWidth
=
input
[
3
];
size_t
filterHeight
=
getFilterHeight
(
filter
);
size_t
filterWidth
=
getFilterWidth
(
filter
);
size_t
outputChannels
=
output
[
1
];
size_t
outputHeight
=
output
[
2
];
size_t
outputWidth
=
output
[
3
];
size_t
filterMultiplier
=
outputChannels
/
groups_
;
CHECK_EQ
(
inputChannels
,
groups_
);
real
*
outputGrad
=
inputs
[
0
].
data
<
real
>
();
real
*
filterData
=
inputs
[
1
].
data
<
real
>
();
real
*
inputGrad
=
outputs
[
0
].
data
<
real
>
();
DepthwiseConvGradInputFunctor
<
Device
,
real
>
depthwiseConvGradInput
;
depthwiseConvGradInput
(
outputGrad
,
filterData
,
batchSize
,
outputChannels
,
outputHeight
,
outputWidth
,
inputChannels
,
inputHeight
,
inputWidth
,
filterMultiplier
,
filterHeight
,
filterWidth
,
strideH
(),
strideW
(),
paddingH
(),
paddingW
(),
inputGrad
);
}
};
/*
* \brief Backward filter calculation of depthwise convolution.
*/
template
<
DeviceType
Device
>
class
DepthwiseConvGradFilterFunction
:
public
ConvFunctionBase
{
public:
void
init
(
const
FuncConfig
&
config
)
override
{
ConvFunctionBase
::
init
(
config
);
}
void
check
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
const
TensorShape
&
output
=
inputs
[
0
].
shape
();
const
TensorShape
&
input
=
inputs
[
1
].
shape
();
const
TensorShape
&
filter
=
outputs
[
0
].
shape
();
checkShape
(
input
,
filter
,
output
);
}
void
calc
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
CHECK_EQ
(
numInputs_
,
inputs
.
size
());
CHECK_EQ
(
numOutputs_
,
outputs
.
size
());
CHECK_EQ
(
outputs
[
0
].
getArgType
(),
ADD_TO
);
check
(
inputs
,
outputs
);
const
TensorShape
&
output
=
inputs
[
0
].
shape
();
const
TensorShape
&
input
=
inputs
[
1
].
shape
();
const
TensorShape
&
filter
=
outputs
[
0
].
shape
();
size_t
batchSize
=
input
[
0
];
size_t
inputChannels
=
input
[
1
];
size_t
inputHeight
=
input
[
2
];
size_t
inputWidth
=
input
[
3
];
size_t
filterHeight
=
getFilterHeight
(
filter
);
size_t
filterWidth
=
getFilterWidth
(
filter
);
size_t
outputChannels
=
output
[
1
];
size_t
outputHeight
=
output
[
2
];
size_t
outputWidth
=
output
[
3
];
size_t
filterMultiplier
=
outputChannels
/
groups_
;
CHECK_EQ
(
inputChannels
,
groups_
);
real
*
outputGrad
=
inputs
[
0
].
data
<
real
>
();
real
*
inputData
=
inputs
[
1
].
data
<
real
>
();
real
*
filterGrad
=
outputs
[
0
].
data
<
real
>
();
int
size
=
outputChannels
*
filterHeight
*
filterWidth
*
outputHeight
*
outputWidth
;
resizeBuffer
<
Device
>
(
size
);
real
*
colData
=
reinterpret_cast
<
real
*>
(
memory_
->
getBuf
());
DepthwiseConvGradFilterFunctor
<
Device
,
real
>
depthwiseConvGradFilter
;
depthwiseConvGradFilter
(
outputGrad
,
inputData
,
batchSize
,
outputChannels
,
outputHeight
,
outputWidth
,
inputChannels
,
inputHeight
,
inputWidth
,
filterMultiplier
,
filterHeight
,
filterWidth
,
strideH
(),
strideW
(),
paddingH
(),
paddingW
(),
colData
,
filterGrad
);
}
};
REGISTER_TYPED_FUNC
(
DepthwiseConv
,
CPU
,
DepthwiseConvFunction
);
REGISTER_TYPED_FUNC
(
DepthwiseConvGradInput
,
CPU
,
DepthwiseConvGradInputFunction
);
REGISTER_TYPED_FUNC
(
DepthwiseConvGradFilter
,
CPU
,
DepthwiseConvGradFilterFunction
);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC
(
DepthwiseConv
,
GPU
,
DepthwiseConvFunction
);
REGISTER_TYPED_FUNC
(
DepthwiseConvGradInput
,
GPU
,
DepthwiseConvGradInputFunction
);
REGISTER_TYPED_FUNC
(
DepthwiseConvGradFilter
,
GPU
,
DepthwiseConvGradFilterFunction
);
#endif
}
// namespace paddle
paddle/function/DepthwiseConvOp.h
0 → 100644
浏览文件 @
91d2a57a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "TensorType.h"
namespace
paddle
{
/**
*\brief Depthwise convolution forward. The outputData
* of depthwise convolution is same with ExpandConvLayer
* when groups equals inputChannels in ExpandConvLayer.
*
* \param[in] inputData input data.
* \param[in] filterData the Paramters of the depthwise conv layer..
* \param[in] batchSize batch size of input data.
* \param[in] outputChannels channels of outputData.
* \param[in] outputHeight height of outputData.
* \param[in] outputWidth width of outputData.
* \param[in] inputChannels channels of inputData.
* \param[in] inputHeight height of inputData.
* \param[in] inputWidth width of inputData..
* \param[in] filterMultiplier equals to outputChannels/groups_.
* \param[in] filterHeight height of filter.
* \param[in] filterWidth widht of filter.
* \param[in] strideH stride size in height direction.
* \param[in] strideW stride size in width direction.
* \param[in] paddingH padding size in height direction.
* \param[in] paddingW padding size in width direction.
* \param[out] outputData outputData.
*
*/
template
<
DeviceType
Device
,
class
T
>
class
DepthwiseConvFunctor
{
public:
void
operator
()(
const
T
*
inputData
,
const
T
*
filterData
,
int
batchSize
,
int
outputChannels
,
int
outputHeight
,
int
outputWidth
,
int
inputChannels
,
int
inputHeight
,
int
inputWidth
,
int
filterMultiplier
,
int
filterHeight
,
int
filterWidth
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
T
*
outputData
);
};
/**
*\brief Functor tot compute the depthwise convolution backprop w.r.t input.
*
*
* \param[in] outputGradData the grad data of output.
* \param[in] filterData the Paramters of the depthwise conv layer..
* \param[in] batchSize batch size of input data.
* \param[in] outputChannels channels of outputData.
* \param[in] outputHeight height of outputData.
* \param[in] outputWidth width of outputData.
* \param[in] inputChannels channels of input data.
* \param[in] inputHeight height of inputData.
* \param[in] inputWidth width of inputData.
* \param[in] filterMultiplier equals to outputChannels/groups_.
* \param[in] filterHeight height of filter.
* \param[in] filterWidth widht of filter.
* \param[in] strideH stride size in height direction.
* \param[in] strideW stride size in width direction.
* \param[in] paddingH padding size in height direction.
* \param[in] paddingW padding size in width direction.
* \param[out] inputGrad the grad data of input.
*
*/
template
<
DeviceType
Device
,
class
T
>
class
DepthwiseConvGradInputFunctor
{
public:
void
operator
()(
const
T
*
outputGrad
,
const
T
*
filterData
,
int
batchSize
,
int
outputChannels
,
int
outputHeight
,
int
outputWidth
,
int
inputChannels
,
int
inputHeight
,
int
inputWidth
,
int
filterMultiplier
,
int
filterHeight
,
int
filterWidth
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
T
*
inputGrad
);
};
/**
*\brief Functor tot compute the depthwise convolution backprop w.r.t filter.
*
* \param[in] outputGradData the grad data of output.
* \param[in] inputData inputData.
* \param[in] batchSize batch size of input data.
* \param[in] outputChannels channels of outputData.
* \param[in] outputHeight height of outputData.
* \param[in] outputWidth width of outputData.
* \param[in] inputChannels channels of input data.
* \param[in] inputHeight height of inputData.
* \param[in] inputWidth width of inputData.
* \param[in] filterMultiplier equals to outputChannels/groups_.
* \param[in] filterHeight height of filter.
* \param[in] filterWidth widht of filter.
* \param[in] strideH stride size in height direction.
* \param[in] strideW stride size in width direction.
* \param[in] paddingH padding size in height direction.
* \param[in] paddingW padding size in width direction.
* \param[in] colData Auxiliary data when calculating filterGrad.
* \param[in] multiplierData Auxiliary data when calculating filterGrad.
* \param[out] filterGrad the grad data of filter.
*
*/
template
<
DeviceType
Device
,
class
T
>
class
DepthwiseConvGradFilterFunctor
{
public:
void
operator
()(
const
T
*
outputGrad
,
const
T
*
inputData
,
int
batchSize
,
int
outputChannels
,
int
outputHeight
,
int
outputWidth
,
int
inputChannels
,
int
inputHeight
,
int
inputWidth
,
int
filterMultiplier
,
int
filterHeight
,
int
filterWidth
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
T
*
colData
,
T
*
filterGrad
);
};
}
// namespace paddle
paddle/function/DepthwiseConvOpGpu.cu
0 → 100644
浏览文件 @
91d2a57a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "DepthwiseConvOp.h"
#include "GemmFunctor.h"
#include "paddle/math/BaseMatrix.h"
namespace
paddle
{
// CUDA kernel to compute the depthwise convolution forward pass
template
<
class
T
>
__global__
void
ConvolutionDepthwiseForward
(
const
int
nthreads
,
const
T
*
const
inputData
,
const
T
*
const
filterData
,
const
int
batchSize
,
const
int
outputChannels
,
const
int
outputHeight
,
const
int
outputWidth
,
const
int
inputChannels
,
const
int
inputHeight
,
const
int
inputWidth
,
const
int
filterMultiplier
,
const
int
filterHeight
,
const
int
filterWidth
,
const
int
strideH
,
const
int
strideW
,
const
int
paddingH
,
const
int
paddingW
,
T
*
const
outputData
)
{
int
index
=
(
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
const
int
batch
=
index
/
outputChannels
/
outputHeight
/
outputWidth
;
const
int
c_out
=
(
index
/
outputHeight
/
outputWidth
)
%
outputChannels
;
const
int
h_out
=
(
index
/
outputWidth
)
%
outputHeight
;
const
int
w_out
=
index
%
outputWidth
;
const
int
c_in
=
c_out
/
filterMultiplier
;
const
T
*
weight
=
filterData
+
c_out
*
filterHeight
*
filterWidth
;
T
value
=
0
;
const
int
h_in_start
=
-
paddingH
+
h_out
*
strideH
;
const
int
w_in_start
=
-
paddingW
+
w_out
*
strideW
;
const
int
h_in_end
=
-
paddingH
+
h_out
*
strideH
+
filterHeight
-
1
;
const
int
w_in_end
=
-
paddingW
+
w_out
*
strideW
+
filterWidth
-
1
;
if
((
h_in_start
>=
0
)
&&
(
h_in_end
<
inputHeight
)
&&
(
w_in_start
>=
0
)
&&
(
w_in_end
<
inputWidth
))
{
for
(
int
kh
=
0
;
kh
<
filterHeight
;
++
kh
)
{
for
(
int
kw
=
0
;
kw
<
filterWidth
;
++
kw
)
{
const
int
h_in
=
-
paddingH
+
h_out
*
strideH
+
kh
;
const
int
w_in
=
-
paddingW
+
w_out
*
strideW
+
kw
;
const
int
offset
=
((
batch
*
inputChannels
+
c_in
)
*
inputHeight
+
h_in
)
*
inputWidth
+
w_in
;
value
+=
(
*
weight
)
*
inputData
[
offset
];
++
weight
;
}
}
}
else
{
for
(
int
kh
=
0
;
kh
<
filterHeight
;
++
kh
)
{
for
(
int
kw
=
0
;
kw
<
filterWidth
;
++
kw
)
{
const
int
h_in
=
-
paddingH
+
h_out
*
strideH
+
kh
;
const
int
w_in
=
-
paddingW
+
w_out
*
strideW
+
kw
;
if
((
h_in
>=
0
)
&&
(
h_in
<
inputHeight
)
&&
(
w_in
>=
0
)
&&
(
w_in
<
inputWidth
))
{
const
int
offset
=
((
batch
*
inputChannels
+
c_in
)
*
inputHeight
+
h_in
)
*
inputWidth
+
w_in
;
value
+=
(
*
weight
)
*
inputData
[
offset
];
}
++
weight
;
}
}
}
outputData
[
index
]
=
value
;
}
}
// CUDA kernel to compute the depthwise convolution backprop w.r.t input.
template
<
class
T
>
__global__
void
ConvolutionDepthwiseInputBackward
(
const
int
nthreads
,
const
T
*
const
top_diff
,
const
T
*
const
weight_data
,
const
int
num
,
const
int
outputChannels
,
const
int
outputHeight
,
const
int
outputWidth
,
const
int
inputChannels
,
const
int
inputHeight
,
const
int
inputWidth
,
const
int
filterMultiplier
,
const
int
filterHeight
,
const
int
filterWidth
,
const
int
strideH
,
const
int
strideW
,
const
int
paddingH
,
const
int
paddingW
,
T
*
const
bottom_diff
)
{
int
index
=
(
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
const
int
batch
=
index
/
inputChannels
/
inputHeight
/
inputWidth
;
const
int
c_in
=
(
index
/
inputHeight
/
inputWidth
)
%
inputChannels
;
const
int
h_in
=
(
index
/
inputWidth
)
%
inputHeight
;
const
int
w_in
=
index
%
inputWidth
;
const
int
c_out_start
=
c_in
*
filterMultiplier
;
int
h_out_start
=
(
h_in
-
filterHeight
+
paddingH
+
strideH
)
/
strideH
;
h_out_start
=
0
>
h_out_start
?
0
:
h_out_start
;
int
h_out_end
=
(
h_in
+
paddingH
)
/
strideH
;
h_out_end
=
outputHeight
-
1
<
h_out_end
?
outputHeight
-
1
:
h_out_end
;
int
w_out_start
=
(
w_in
-
filterWidth
+
paddingW
+
strideW
)
/
strideW
;
w_out_start
=
0
>
w_out_start
?
0
:
w_out_start
;
int
w_out_end
=
(
w_in
+
paddingW
)
/
strideW
;
w_out_end
=
outputWidth
-
1
<
w_out_end
?
outputWidth
-
1
:
w_out_end
;
T
value
=
0
;
for
(
int
c_out
=
c_out_start
;
c_out
<
c_out_start
+
filterMultiplier
;
c_out
++
)
{
for
(
int
h_out
=
h_out_start
;
h_out
<=
h_out_end
;
++
h_out
)
{
const
int
filter_h
=
h_in
+
paddingH
-
h_out
*
strideH
;
for
(
int
w_out
=
w_out_start
;
w_out
<=
w_out_end
;
++
w_out
)
{
const
int
filter_w
=
w_in
+
paddingW
-
w_out
*
strideW
;
const
int
filter_offset
=
c_out
*
filterHeight
*
filterWidth
+
filter_h
*
filterWidth
+
filter_w
;
const
int
top_diff_offset
=
((
batch
*
outputChannels
+
c_out
)
*
outputHeight
+
h_out
)
*
outputWidth
+
w_out
;
value
+=
top_diff
[
top_diff_offset
]
*
weight_data
[
filter_offset
];
}
}
}
bottom_diff
[
index
]
+=
value
;
}
}
// CUDA kernel to compute the depthwise convolution backprop w.r.t filter.
template
<
class
T
>
__global__
void
ConvolutionDepthwiseFilterBackward
(
const
int
num_i
,
const
int
nthreads
,
const
T
*
const
top_diff
,
const
T
*
const
inputData
,
const
int
num
,
const
int
outputChannels
,
const
int
outputHeight
,
const
int
outputWidth
,
const
int
inputChannels
,
const
int
inputHeight
,
const
int
inputWidth
,
const
int
filterMultiplier
,
const
int
filterHeight
,
const
int
filterWidth
,
const
int
strideH
,
const
int
strideW
,
const
int
paddingH
,
const
int
paddingW
,
T
*
const
buffer_data
)
{
int
index
=
(
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
const
int
h_out
=
(
index
/
outputWidth
)
%
outputHeight
;
const
int
w_out
=
index
%
outputWidth
;
const
int
kh
=
(
index
/
filterWidth
/
outputHeight
/
outputWidth
)
%
filterHeight
;
const
int
kw
=
(
index
/
outputHeight
/
outputWidth
)
%
filterWidth
;
const
int
h_in
=
-
paddingH
+
h_out
*
strideH
+
kh
;
const
int
w_in
=
-
paddingW
+
w_out
*
strideW
+
kw
;
if
((
h_in
>=
0
)
&&
(
h_in
<
inputHeight
)
&&
(
w_in
>=
0
)
&&
(
w_in
<
inputWidth
))
{
const
int
c_out
=
index
/
(
filterHeight
*
filterWidth
*
outputHeight
*
outputWidth
);
const
int
c_in
=
c_out
/
filterMultiplier
;
const
int
batch
=
num_i
;
const
int
top_offset
=
((
batch
*
outputChannels
+
c_out
)
*
outputHeight
+
h_out
)
*
outputWidth
+
w_out
;
const
int
bottom_offset
=
((
batch
*
inputChannels
+
c_in
)
*
inputHeight
+
h_in
)
*
inputWidth
+
w_in
;
buffer_data
[
index
]
=
top_diff
[
top_offset
]
*
inputData
[
bottom_offset
];
}
else
{
buffer_data
[
index
]
=
0
;
}
}
}
template
<
class
T
>
class
DepthwiseConvFunctor
<
DEVICE_TYPE_GPU
,
T
>
{
public:
void
operator
()(
const
T
*
inputData
,
const
T
*
filterData
,
int
batchSize
,
int
outputChannels
,
int
outputHeight
,
int
outputWidth
,
int
inputChannels
,
int
inputHeight
,
int
inputWidth
,
int
filterMultiplier
,
int
filterHeight
,
int
filterWidth
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
T
*
outputData
){
int
outputSize
=
batchSize
*
outputChannels
*
outputHeight
*
outputWidth
;
size_t
blocks
=
(
outputSize
+
1024
-
1
)
/
1024
;
size_t
blockX
=
512
;
size_t
blockY
=
(
blocks
+
512
-
1
)
/
512
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blockX
,
blockY
);
ConvolutionDepthwiseForward
<
T
>
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
outputSize
,
inputData
,
filterData
,
batchSize
,
outputChannels
,
outputHeight
,
outputWidth
,
inputChannels
,
inputHeight
,
inputWidth
,
filterMultiplier
,
filterHeight
,
filterWidth
,
strideH
,
strideW
,
paddingH
,
paddingW
,
outputData
);
}
};
template
<
class
T
>
class
DepthwiseConvGradInputFunctor
<
DEVICE_TYPE_GPU
,
T
>
{
public:
void
operator
()(
const
T
*
outputGrad
,
const
T
*
filterData
,
int
batchSize
,
int
outputChannels
,
int
outputHeight
,
int
outputWidth
,
int
inputChannels
,
int
inputHeight
,
int
inputWidth
,
int
filterMultiplier
,
int
filterHeight
,
int
filterWidth
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
T
*
inputGrad
){
int
inputSize
=
batchSize
*
inputChannels
*
inputHeight
*
inputWidth
;
size_t
blocks
=
(
inputSize
+
1024
-
1
)
/
1024
;
size_t
blockX
=
512
;
size_t
blockY
=
(
blocks
+
512
-
1
)
/
512
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blockX
,
blockY
);
ConvolutionDepthwiseInputBackward
<
T
>
// NOLINT_NEXT_LINE(whitespace/operators)
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
inputSize
,
outputGrad
,
filterData
,
batchSize
,
outputChannels
,
outputHeight
,
outputWidth
,
inputChannels
,
inputHeight
,
inputWidth
,
filterMultiplier
,
filterHeight
,
filterWidth
,
strideH
,
strideW
,
paddingH
,
paddingW
,
inputGrad
);
}
};
template
<
class
T
>
class
DepthwiseConvGradFilterFunctor
<
DEVICE_TYPE_GPU
,
T
>
{
public:
void
operator
()(
const
T
*
outputGrad
,
const
T
*
inputData
,
int
batchSize
,
int
outputChannels
,
int
outputHeight
,
int
outputWidth
,
int
inputChannels
,
int
inputHeight
,
int
inputWidth
,
int
filterMultiplier
,
int
filterHeight
,
int
filterWidth
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
T
*
colData
,
T
*
filterGrad
){
int
colDataSize
=
outputChannels
*
filterHeight
*
filterWidth
*
outputHeight
*
outputWidth
;
size_t
blocks
=
(
colDataSize
+
1024
-
1
)
/
1024
;
size_t
blockX
=
512
;
size_t
blockY
=
(
blocks
+
512
-
1
)
/
512
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blockX
,
blockY
);
BaseMatrix
filterGradMatrix
(
outputChannels
*
filterHeight
*
filterWidth
,
1
,
filterGrad
,
false
,
true
);
for
(
int
i
=
0
;
i
<
batchSize
;
i
++
)
{
ConvolutionDepthwiseFilterBackward
<
T
>
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
i
,
colDataSize
,
outputGrad
,
inputData
,
batchSize
,
outputChannels
,
outputHeight
,
outputWidth
,
inputChannels
,
inputHeight
,
inputWidth
,
filterMultiplier
,
filterHeight
,
filterWidth
,
strideH
,
strideW
,
paddingH
,
paddingW
,
colData
);
int
K
=
outputHeight
*
outputWidth
;
int
M
=
colDataSize
/
K
;
BaseMatrix
colMatrix
(
M
,
K
,
colData
,
false
,
true
);
filterGradMatrix
.
sumRows
(
colMatrix
,
(
T
)
1.0
,
(
T
)
1.0
);
}
}
};
#ifdef PADDLE_TYPE_DOUBLE
template
class
DepthwiseConvGradInputFunctor
<
DEVICE_TYPE_GPU
,
double
>;
template
class
DepthwiseConvFunctor
<
DEVICE_TYPE_GPU
,
double
>;
template
class
DepthwiseConvGradFilterFunctor
<
DEVICE_TYPE_GPU
,
double
>;
#else
template
class
DepthwiseConvGradInputFunctor
<
DEVICE_TYPE_GPU
,
float
>;
template
class
DepthwiseConvFunctor
<
DEVICE_TYPE_GPU
,
float
>;
template
class
DepthwiseConvGradFilterFunctor
<
DEVICE_TYPE_GPU
,
float
>;
#endif
}
// namespace paddle
paddle/gserver/layers/ExpandConvLayer.cpp
浏览文件 @
91d2a57a
...
...
@@ -38,10 +38,25 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
inputShape_
.
resize
(
numInputs
);
filterShape_
.
resize
(
numInputs
);
outputShape_
.
resize
(
numInputs
);
std
::
string
convType
;
std
::
string
convGradInputType
;
std
::
string
convGradFilterType
;
for
(
int
i
=
0
;
i
<
config_
.
inputs_size
();
i
++
)
{
std
::
vector
<
size_t
>
paddings
=
{(
size_t
)
paddingY_
[
i
],
(
size_t
)
padding_
[
i
]};
std
::
vector
<
size_t
>
strides
=
{(
size_t
)
strideY_
[
i
],
(
size_t
)
stride_
[
i
]};
if
(
useGpu_
&&
(
size_t
)
groups_
[
i
]
==
(
size_t
)
channels_
[
i
]
&&
!
isDeconv_
)
{
convType
=
"DepthwiseConv"
;
convGradInputType
=
"DepthwiseConvGradInput"
;
convGradFilterType
=
"DepthwiseConvGradFilter"
;
}
else
{
convType
=
"GemmConv"
;
convGradInputType
=
"GemmConvGradInput"
;
convGradFilterType
=
"GemmConvGradFilter"
;
}
if
(
FLAGS_use_nnpack
)
{
CHECK_EQ
(
isDeconv_
,
false
);
createFunction
(
forward_
,
...
...
@@ -53,21 +68,21 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
.
set
(
"algo"
,
std
::
string
(
"auto"
)));
}
else
{
createFunction
(
forward_
,
!
isDeconv_
?
"GemmConv"
:
"GemmConvGradInput"
,
!
isDeconv_
?
convType
:
convGradInputType
,
FuncConfig
()
.
set
(
"paddings"
,
paddings
)
.
set
(
"strides"
,
strides
)
.
set
(
"groups"
,
(
size_t
)
groups_
[
i
]));
createFunction
(
backward_
,
!
isDeconv_
?
"GemmConvGradInput"
:
"GemmConv"
,
!
isDeconv_
?
convGradInputType
:
convType
,
FuncConfig
()
.
set
(
"paddings"
,
paddings
)
.
set
(
"strides"
,
strides
)
.
set
(
"groups"
,
(
size_t
)
groups_
[
i
]));
createFunction
(
backward_
,
"GemmConvGradFilter"
,
convGradFilterType
,
FuncConfig
()
.
set
(
"paddings"
,
paddings
)
.
set
(
"strides"
,
strides
)
...
...
paddle/gserver/tests/test_LayerGrad.cpp
浏览文件 @
91d2a57a
...
...
@@ -347,6 +347,55 @@ TEST(Layer, CosSimVecMatLayer) {
}
}
void
testDepthwiseConvLayer
(
const
string
&
type
,
bool
useGpu
)
{
TestConfig
config
;
config
.
biasSize
=
32
;
config
.
layerConfig
.
set_type
(
type
);
config
.
layerConfig
.
set_num_filters
(
32
);
config
.
layerConfig
.
set_partial_sum
(
1
);
config
.
layerConfig
.
set_shared_biases
(
true
);
config
.
inputDefs
.
push_back
({
INPUT_DATA
,
"layer_0"
,
2048
,
192
});
LayerInputConfig
*
input
=
config
.
layerConfig
.
add_inputs
();
ConvConfig
*
conv
=
input
->
mutable_conv_conf
();
conv
->
set_filter_size
(
2
);
conv
->
set_filter_size_y
(
3
);
conv
->
set_channels
(
16
);
conv
->
set_padding
(
0
);
conv
->
set_padding_y
(
1
);
conv
->
set_stride
(
2
);
conv
->
set_stride_y
(
2
);
conv
->
set_groups
(
16
);
conv
->
set_filter_channels
(
conv
->
channels
()
/
conv
->
groups
());
conv
->
set_img_size
(
16
);
conv
->
set_img_size_y
(
8
);
conv
->
set_output_x
(
outputSize
(
conv
->
img_size
(),
conv
->
filter_size
(),
conv
->
padding
(),
conv
->
stride
(),
/* caffeMode */
true
));
conv
->
set_output_y
(
outputSize
(
conv
->
img_size_y
(),
conv
->
filter_size_y
(),
conv
->
padding_y
(),
conv
->
stride_y
(),
/* caffeMode */
true
));
config
.
layerConfig
.
set_size
(
conv
->
output_x
()
*
conv
->
output_y
()
*
config
.
layerConfig
.
num_filters
());
testLayerGrad
(
config
,
"depthwise_conv"
,
100
,
false
,
useGpu
);
// Use small batch_size and useWeight=true to test biasGrad
testLayerGrad
(
config
,
"depthwise_conv"
,
2
,
false
,
useGpu
,
true
,
0.02
);
}
TEST
(
Layer
,
depthwiseConvLayer
)
{
// 'depthwise_conv' is a sepecial case of 'exconv' whose
// groups size equals to the input channels size.
testDepthwiseConvLayer
(
"exconv"
,
/* useGpu= */
false
);
#ifndef PADDLE_ONLY_CPU
testDepthwiseConvLayer
(
"exconv"
,
/* useGpu= */
true
);
#endif
}
void
testConvLayer
(
const
string
&
type
,
bool
trans
,
bool
useGpu
)
{
TestConfig
config
;
config
.
biasSize
=
16
;
...
...
python/paddle/trainer/config_parser.py
浏览文件 @
91d2a57a
...
...
@@ -3219,6 +3219,10 @@ def ParameterHook(type, **kwargs):
if
sparsity_ratio
is
not
None
:
hook
.
sparsity_ratio
=
sparsity_ratio
return
hook
elif
type
==
'dpruning'
:
hook
=
ParameterUpdaterHookConfig
()
hook
.
type
=
type
return
hook
else
:
return
None
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录