提交 e3342ff8 编写于 作者: W wanghaoshuang

Fix android build error.

上级 1dc850e4
...@@ -78,7 +78,9 @@ inline void hl_create_convolution_descriptor(hl_convolution_descriptor* conv, ...@@ -78,7 +78,9 @@ inline void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
int padding_height, int padding_height,
int padding_width, int padding_width,
int stride_height, int stride_height,
int stride_width) {} int stride_width,
int dilation_h,
int dilation_w) {}
inline void hl_reset_convolution_descriptor(hl_convolution_descriptor conv, inline void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
hl_tensor_descriptor image, hl_tensor_descriptor image,
...@@ -86,7 +88,9 @@ inline void hl_reset_convolution_descriptor(hl_convolution_descriptor conv, ...@@ -86,7 +88,9 @@ inline void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
int padding_height, int padding_height,
int padding_width, int padding_width,
int stride_height, int stride_height,
int stride_width) {} int stride_width,
int dilation_h,
int dilation_w) {}
inline void hl_destroy_convolution_descriptor(hl_convolution_descriptor conv) {} inline void hl_destroy_convolution_descriptor(hl_convolution_descriptor conv) {}
...@@ -99,7 +103,8 @@ inline void hl_conv_workspace(hl_tensor_descriptor input, ...@@ -99,7 +103,8 @@ inline void hl_conv_workspace(hl_tensor_descriptor input,
int* convBwdDataAlgo, int* convBwdDataAlgo,
size_t* bwdDataLimitBytes, size_t* bwdDataLimitBytes,
int* convBwdFilterAlgo, int* convBwdFilterAlgo,
size_t* bwdFilterLimitBytes) {} size_t* bwdFilterLimitBytes,
bool useDilation) {}
inline void hl_convolution_forward(hl_tensor_descriptor input, inline void hl_convolution_forward(hl_tensor_descriptor input,
real* input_data, real* input_data,
......
...@@ -640,7 +640,8 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv, ...@@ -640,7 +640,8 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
#else #else
if (dilation_h > 1 || dilation_w > 1) { if (dilation_h > 1 || dilation_w > 1) {
LOG(FATAL) LOG(FATAL)
<< "Current cudnn version does't support for dilation convolution."; << "Current cuDNN version does't support for dilation convolution. "
<< "The dilation convolution requires cuDNN >= v6.0.";
} }
CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(hl_conv->desc, CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(hl_conv->desc,
......
...@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef PADDLE_ONLY_CPU
#include <cudnn.h> #include <cudnn.h>
#endif
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -262,8 +264,8 @@ TEST(Projection, conv) { ...@@ -262,8 +264,8 @@ TEST(Projection, conv) {
testProjectionConv(1, false); testProjectionConv(1, false);
testProjectionConv(3, false); testProjectionConv(3, false);
/// test ConvTransProjection /// test ConvTransProjection
/// testProjectionConv(1, true); testProjectionConv(1, true);
/// testProjectionConv(3, true); testProjectionConv(3, true);
} }
#endif #endif
......
...@@ -862,7 +862,6 @@ class Conv(Cfg): ...@@ -862,7 +862,6 @@ class Conv(Cfg):
filter_size, filter_size,
channels, channels,
padding=None, padding=None,
dilation=None,
stride=None, stride=None,
groups=None, groups=None,
filter_channels=None, filter_channels=None,
...@@ -871,8 +870,9 @@ class Conv(Cfg): ...@@ -871,8 +870,9 @@ class Conv(Cfg):
caffe_mode=True, caffe_mode=True,
filter_size_y=None, filter_size_y=None,
padding_y=None, padding_y=None,
dilation_y=None, stride_y=None,
stride_y=None): dilation=None,
dilation_y=None):
self.add_keys(locals()) self.add_keys(locals())
if filter_size_y is None: if filter_size_y is None:
self.filter_size_y = filter_size self.filter_size_y = filter_size
......
...@@ -2340,7 +2340,7 @@ def img_conv_layer(input, ...@@ -2340,7 +2340,7 @@ def img_conv_layer(input,
groups=1, groups=1,
stride=1, stride=1,
padding=0, padding=0,
dilation=0, dilation=1,
bias_attr=None, bias_attr=None,
param_attr=None, param_attr=None,
shared_biases=True, shared_biases=True,
...@@ -2472,9 +2472,6 @@ def img_conv_layer(input, ...@@ -2472,9 +2472,6 @@ def img_conv_layer(input,
else: else:
dilation_y = dilation dilation_y = dilation
if dilation > 1 or dilation_y > 1:
assert layer_type in ["cudnn_conv", "cudnn_convt"]
if param_attr.attr.get('initial_smart'): if param_attr.attr.get('initial_smart'):
# special initial for conv layers. # special initial for conv layers.
init_w = (2.0 / (filter_size**2 * num_channels))**0.5 init_w = (2.0 / (filter_size**2 * num_channels))**0.5
...@@ -2484,6 +2481,8 @@ def img_conv_layer(input, ...@@ -2484,6 +2481,8 @@ def img_conv_layer(input,
param_attr.attr["initial_smart"] = False param_attr.attr["initial_smart"] = False
if layer_type: if layer_type:
if dilation > 1 or dilation_y > 1:
assert layer_type in ["cudnn_conv", "cudnn_convt"]
if trans: if trans:
assert layer_type in ["exconvt", "cudnn_convt"] assert layer_type in ["exconvt", "cudnn_convt"]
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册