提交 e3342ff8 编写于 作者: W wanghaoshuang

Fix android build error.

上级 1dc850e4
......@@ -78,7 +78,9 @@ inline void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
int padding_height,
int padding_width,
int stride_height,
int stride_width) {}
int stride_width,
int dilation_h,
int dilation_w) {}
inline void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
hl_tensor_descriptor image,
......@@ -86,7 +88,9 @@ inline void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
int padding_height,
int padding_width,
int stride_height,
int stride_width) {}
int stride_width,
int dilation_h,
int dilation_w) {}
inline void hl_destroy_convolution_descriptor(hl_convolution_descriptor conv) {}
......@@ -99,7 +103,8 @@ inline void hl_conv_workspace(hl_tensor_descriptor input,
int* convBwdDataAlgo,
size_t* bwdDataLimitBytes,
int* convBwdFilterAlgo,
size_t* bwdFilterLimitBytes) {}
size_t* bwdFilterLimitBytes,
bool useDilation) {}
inline void hl_convolution_forward(hl_tensor_descriptor input,
real* input_data,
......
......@@ -640,7 +640,8 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
#else
if (dilation_h > 1 || dilation_w > 1) {
LOG(FATAL)
<< "Current cudnn version does't support for dilation convolution.";
<< "Current cuDNN version does't support for dilation convolution. "
<< "The dilation convolution requires cuDNN >= v6.0.";
}
CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(hl_conv->desc,
......
......@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_ONLY_CPU
#include <cudnn.h>
#endif
#include <gtest/gtest.h>
#include <string>
#include <vector>
......@@ -262,8 +264,8 @@ TEST(Projection, conv) {
testProjectionConv(1, false);
testProjectionConv(3, false);
/// test ConvTransProjection
/// testProjectionConv(1, true);
/// testProjectionConv(3, true);
testProjectionConv(1, true);
testProjectionConv(3, true);
}
#endif
......
......@@ -862,7 +862,6 @@ class Conv(Cfg):
filter_size,
channels,
padding=None,
dilation=None,
stride=None,
groups=None,
filter_channels=None,
......@@ -871,8 +870,9 @@ class Conv(Cfg):
caffe_mode=True,
filter_size_y=None,
padding_y=None,
dilation_y=None,
stride_y=None):
stride_y=None,
dilation=None,
dilation_y=None):
self.add_keys(locals())
if filter_size_y is None:
self.filter_size_y = filter_size
......
......@@ -2340,7 +2340,7 @@ def img_conv_layer(input,
groups=1,
stride=1,
padding=0,
dilation=0,
dilation=1,
bias_attr=None,
param_attr=None,
shared_biases=True,
......@@ -2472,9 +2472,6 @@ def img_conv_layer(input,
else:
dilation_y = dilation
if dilation > 1 or dilation_y > 1:
assert layer_type in ["cudnn_conv", "cudnn_convt"]
if param_attr.attr.get('initial_smart'):
# special initial for conv layers.
init_w = (2.0 / (filter_size**2 * num_channels))**0.5
......@@ -2484,6 +2481,8 @@ def img_conv_layer(input,
param_attr.attr["initial_smart"] = False
if layer_type:
if dilation > 1 or dilation_y > 1:
assert layer_type in ["cudnn_conv", "cudnn_convt"]
if trans:
assert layer_type in ["exconvt", "cudnn_convt"]
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册