未验证 提交 ff568afa 编写于 作者: F furnace 提交者: GitHub

[NPU] add npu support for conv3d and conv3d_grad (#38480)

* [NPU] add npu support for conv3d and conv3d_grad

* [NPU] delete failed unittests due to Ascend not support

* [NPU] delete debug codes

* [NPU] optimize codes, notest

* [NPU] remove const_cast

* [NPU] optimize for remove const_cast

* [NPU] fix written errors
上级 81e505df
......@@ -390,6 +390,204 @@ class NPUConvGradOpKernel : public framework::OpKernel<T> {
}
}
};
template <typename T>
class NPUConv3dKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* input = ctx.Input<Tensor>("Input");
const Tensor* filter = ctx.Input<Tensor>("Filter");
Tensor* output = ctx.Output<Tensor>("Output");
const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups");
const std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm");
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_EQ(data_format, "NCDHW",
platform::errors::Unimplemented(
"the data_format must be NCDHW in "
"the npu kernel of conv3d, but got data_format "
"= [%s]",
data_format));
PADDLE_ENFORCE_EQ(groups, 1, platform::errors::Unimplemented(
"the groups must be 1 in "
"the npu kernel of conv3d, but got groups "
"= [%d]",
groups));
output->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<NPUDeviceContext>();
auto input_tensor =
ctx.AllocateTmpTensor<T, NPUDeviceContext>(input->dims(), dev_ctx);
auto filter_tensor =
ctx.AllocateTmpTensor<T, NPUDeviceContext>(filter->dims(), dev_ctx);
auto output_tensor =
ctx.AllocateTmpTensor<T, NPUDeviceContext>(output->dims(), dev_ctx);
input_tensor.ShareDataWith(*input);
filter_tensor.ShareDataWith(*filter);
output_tensor.ShareDataWith(*output);
input_tensor.set_layout(DataLayout::kNCDHW);
filter_tensor.set_layout(DataLayout::kNCDHW);
output_tensor.set_layout(DataLayout::kNCDHW);
// update padding and dilation
auto in_dims = input->dims();
auto filter_dims = filter->dims();
framework::DDim in_data_dims;
framework::DDim filter_data_dims;
in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
filter_data_dims = phi::slice_ddim(filter_dims, 2, in_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
std::vector<int> strides_vec(5, 1);
std::vector<int> dilations_vec(5, 1);
strides_vec[2] = strides[0];
strides_vec[3] = strides[1];
strides_vec[4] = strides[2];
dilations_vec[2] = dilations[0];
dilations_vec[3] = dilations[1];
dilations_vec[4] = dilations[2];
auto stream = ctx.template device_context<NPUDeviceContext>().stream();
const auto& runner =
NpuOpRunner("Conv3D", {input_tensor, filter_tensor}, {output_tensor},
{{"strides", strides_vec},
{"pads", paddings},
{"dilations", dilations_vec},
{"groups", groups},
{"data_format", data_format}});
runner.Run(stream);
}
};
template <typename T>
class NPUConv3dGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* input = ctx.Input<Tensor>("Input");
const Tensor* filter = ctx.Input<Tensor>("Filter");
const Tensor* output_grad =
ctx.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups");
const std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm");
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_EQ(data_format, "NCDHW",
platform::errors::Unimplemented(
"the data_format must be NCDHW in "
"the npu kernel of conv3d, but got data_format "
"= [%s]",
data_format));
PADDLE_ENFORCE_EQ(groups, 1, platform::errors::Unimplemented(
"the groups must be 1 in "
"the npu kernel of conv3d, but got groups "
"= [%d]",
groups));
auto& dev_ctx = ctx.template device_context<NPUDeviceContext>();
auto input_tensor =
ctx.AllocateTmpTensor<T, NPUDeviceContext>(input->dims(), dev_ctx);
auto filter_tensor =
ctx.AllocateTmpTensor<T, NPUDeviceContext>(filter->dims(), dev_ctx);
auto output_grad_tensor = ctx.AllocateTmpTensor<T, NPUDeviceContext>(
output_grad->dims(), dev_ctx);
input_tensor.ShareDataWith(*input);
filter_tensor.ShareDataWith(*filter);
output_grad_tensor.ShareDataWith(*output_grad);
input_tensor.set_layout(DataLayout::kNCDHW);
filter_tensor.set_layout(DataLayout::kNCDHW);
output_grad_tensor.set_layout(DataLayout::kNCDHW);
// update padding and dilation
auto in_dims = input->dims();
auto filter_dims = filter->dims();
framework::DDim in_data_dims;
framework::DDim filter_data_dims;
in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
filter_data_dims = phi::slice_ddim(filter_dims, 2, in_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
std::vector<int> strides_vec(5, 1);
std::vector<int> dilations_vec(5, 1);
strides_vec[2] = strides[0];
strides_vec[3] = strides[1];
strides_vec[4] = strides[2];
dilations_vec[2] = dilations[0];
dilations_vec[3] = dilations[1];
dilations_vec[4] = dilations[2];
auto stream = ctx.template device_context<NPUDeviceContext>().stream();
if (filter_grad) {
filter_grad->mutable_data<T>(ctx.GetPlace());
std::vector<int> filter_shape_vec = phi::vectorize<int>(filter->dims());
Tensor filter_grad_tensor = ctx.AllocateTmpTensor<T, NPUDeviceContext>(
filter_grad->dims(), dev_ctx);
filter_grad_tensor.ShareDataWith(*filter_grad);
filter_grad_tensor.set_layout(DataLayout::kNCDHW);
const auto& runner = NpuOpRunner(
"Conv3DBackpropFilterD", {input_tensor, output_grad_tensor},
{filter_grad_tensor}, {{"filter_size", filter_shape_vec},
{"strides", strides_vec},
{"pads", paddings},
{"dilations", dilations_vec},
{"groups", groups},
{"data_format", data_format}});
runner.Run(stream);
}
if (input_grad) {
input_grad->mutable_data<T>(ctx.GetPlace());
std::vector<int> input_shape_vec = phi::vectorize<int>(input->dims());
Tensor input_grad_tensor = ctx.AllocateTmpTensor<T, NPUDeviceContext>(
input_grad->dims(), dev_ctx);
input_grad_tensor.ShareDataWith(*input_grad);
input_grad_tensor.set_layout(DataLayout::kNCDHW);
const auto& runner = NpuOpRunner(
"Conv3DBackpropInputD", {filter_tensor, output_grad_tensor},
{input_grad_tensor}, {{"input_size", input_shape_vec},
{"strides", strides_vec},
{"pads", paddings},
{"dilations", dilations_vec},
{"groups", groups},
{"data_format", data_format}});
runner.Run(stream);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -408,3 +606,9 @@ REGISTER_OP_NPU_KERNEL(conv2d, ops::NPUConvOpKernel<float>,
REGISTER_OP_NPU_KERNEL(conv2d_grad, ops::NPUConvGradOpKernel<float>,
ops::NPUConvGradOpKernel<plat::float16>);
REGISTER_OP_NPU_KERNEL(conv3d, ops::NPUConv3dKernel<float>,
ops::NPUConv3dKernel<plat::float16>);
REGISTER_OP_NPU_KERNEL(conv3d_grad, ops::NPUConv3dGradKernel<float>,
ops::NPUConv3dGradKernel<plat::float16>);
......@@ -47,6 +47,8 @@ static std::map<framework::proto::VarType::Type, aclDataType>
static std::map<DataLayout, aclFormat> DATA_LAYOUT_2_ACL_FORMAT = {
{DataLayout::kNCHW, ACL_FORMAT_NCHW},
{DataLayout::kNHWC, ACL_FORMAT_NHWC},
{DataLayout::kNCDHW, ACL_FORMAT_NCDHW},
{DataLayout::kNDHWC, ACL_FORMAT_NDHWC},
{DataLayout::kAnyLayout, ACL_FORMAT_ND},
};
......
......@@ -30,6 +30,8 @@ enum class DataLayout {
SPARSE_COO,
SPARSE_CSR,
NUM_DATA_LAYOUTS,
NDHWC,
NCDHW,
// See Note [ Why we need ALL in basic kernel key member? ]
ALL_LAYOUT = UNDEFINED,
// Note: Unify phi DataLayout and fluid::framework::DataLayout,
......@@ -43,6 +45,8 @@ enum class DataLayout {
kNHWC = NHWC,
kNCHW = NCHW,
kMKLDNN = MKLDNN, // all layouts supported by MKLDNN internally
kNDHWC = NDHWC,
kNCDHW = NCDHW,
};
} // namespace experimental
......@@ -70,6 +74,10 @@ inline DataLayout StringToDataLayout(const std::string& str) {
return DataLayout::SPARSE_COO;
} else if (s == "SPARSE_CSR") {
return DataLayout::SPARSE_CSR;
} else if (s == "NDHWC") {
return DataLayout::kNDHWC;
} else if (s == "NCDHW") {
return DataLayout::kNCDHW;
} else {
PD_THROW("Unknown data layout type string: ", s, ".");
}
......@@ -89,6 +97,10 @@ inline std::string DataLayoutToString(const DataLayout& layout) {
return "SPARSE_COO";
case DataLayout::SPARSE_CSR:
return "SPARSE_CSR";
case DataLayout::kNDHWC:
return "NDHWC";
case DataLayout::kNCDHW:
return "NCDHW";
default:
PD_THROW("Unknown Data Layout type ", static_cast<int>(layout), ".");
}
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
import paddle
import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
from test_conv3d_op import conv3d_forward_naive
paddle.enable_static()
def create_test_padding_SAME_class(parent):
class TestPaddingSMAECase(parent):
def init_paddings(self):
self.pad = [0, 0, 0]
self.padding_algorithm = "SAME"
cls_name = "{0}_{1}".format(parent.__name__, "PaddingSAMEOp")
TestPaddingSMAECase.__name__ = cls_name
globals()[cls_name] = TestPaddingSMAECase
def create_test_padding_VALID_class(parent):
class TestPaddingVALIDCase(parent):
def init_paddings(self):
self.pad = [1, 1, 1]
self.padding_algorithm = "VALID"
cls_name = "{0}_{1}".format(parent.__name__, "PaddingVALIDOp")
TestPaddingVALIDCase.__name__ = cls_name
globals()[cls_name] = TestPaddingVALIDCase
def create_test_channel_last_class(parent):
class TestChannelLastCase(parent):
def init_data_format(self):
self.data_format = "NDHWC"
def init_test_case_2(self):
N, C, D, H, W = self.input_size
self.input_size = [N, D, H, W, C]
cls_name = "{0}_{1}".format(parent.__name__, "ChannelLast")
TestChannelLastCase.__name__ = cls_name
globals()[cls_name] = TestChannelLastCase
def create_test_fp16_class(parent):
class TestFp16Case(parent):
def init_dtype(self):
self.dtype = np.float16
cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
TestFp16Case.__name__ = cls_name
globals()[cls_name] = TestFp16Case
class TestConv3DOp(OpTest):
def setUp(self):
self.op_type = "conv3d"
self.set_npu()
self.init_dtype()
self.init_data_format()
self.init_group()
self.init_dilation()
self.init_test_case()
conv3d_param = {
'stride': self.stride,
'pad': self.pad,
'dilations': self.dilations
}
input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)
output = conv3d_forward_naive(
input,
filter,
self.groups,
conv3d_param, ).astype(self.dtype)
self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
'groups': self.groups,
'dilations': self.dilations,
'data_format': self.data_format
}
self.outputs = {'Output': output}
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-2)
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
self.place, {'Input', 'Filter'},
'Output',
max_relative_error=0.03,
numeric_place=paddle.CPUPlace())
def test_check_grad_no_filter(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
self.place, ['Input'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Filter']),
numeric_place=paddle.CPUPlace())
def test_check_grad_no_input(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
self.place, ['Filter'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Input']),
numeric_place=paddle.CPUPlace())
def set_npu(self):
self.__class__.use_npu = True
self.place = fluid.NPUPlace(0)
def init_dtype(self):
self.dtype = np.float32
def init_data_format(self):
self.data_format = "NCDHW"
def init_group(self):
self.groups = 1
def init_dilation(self):
self.dilations = [1, 1, 1]
def init_test_case(self):
self.pad = [0, 0, 0]
self.stride = [1, 1, 1]
self.input_size = [2, 3, 4, 4, 4] # NCDHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3, 3]
class TestCase1(TestConv3DOp):
def init_test_case(self):
self.pad = [1, 1, 1]
self.stride = [1, 1, 1]
self.input_size = [2, 3, 4, 4, 4] # NCDHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3, 3]
# ---- test asymmetric padding ----
class TestConv3DOp_2(OpTest):
def setUp(self):
self.op_type = "conv3d"
self.set_npu()
self.init_dtype()
self.init_data_format()
self.init_group()
self.init_dilation()
self.init_paddings()
self.init_test_case()
self.init_test_case_2()
conv3d_param = {
'stride': self.stride,
'pad': self.pad,
'dilations': self.dilations
}
input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)
output = conv3d_forward_naive(input, filter, self.groups, conv3d_param,
self.padding_algorithm,
self.data_format).astype(self.dtype)
self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
'padding_algorithm': self.padding_algorithm,
'groups': self.groups,
'dilations': self.dilations,
'data_format': self.data_format
}
self.outputs = {'Output': output}
def test_check_output(self):
self.check_output_with_place(paddle.NPUPlace(0), atol=1e-2)
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
self.place, {'Input', 'Filter'},
'Output',
max_relative_error=0.03,
numeric_place=paddle.CPUPlace())
def test_check_grad_no_filter(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
self.place, ['Input'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Filter']),
numeric_place=paddle.CPUPlace())
def test_check_grad_no_input(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
self.place, ['Filter'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Input']),
numeric_place=paddle.CPUPlace())
def set_npu(self):
self.__class__.use_npu = True
self.place = fluid.NPUPlace(0)
def init_dtype(self):
self.dtype = np.float32
def init_data_format(self):
self.data_format = "NCDHW"
def init_group(self):
self.groups = 1
def init_dilation(self):
self.dilations = [1, 1, 1]
def init_paddings(self):
self.pad = [0, 0, 0]
self.padding_algorithm = "EXPLICIT"
def init_test_case(self):
self.stride = [1, 1, 1]
self.input_size = [2, 3, 4, 4, 4] # NCDHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3, 3]
def init_test_case_2(self):
pass
class TestConv3DOp_AsyPadding(TestConv3DOp_2):
def init_test_case(self):
self.stride = [1, 1, 2]
self.input_size = [2, 3, 4, 4, 4] # NCDHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3, 3]
def init_paddings(self):
self.pad = [1, 0, 1, 0, 0, 2]
self.padding_algorithm = "EXPLICIT"
class TestConv3DOp_DiffDataInDiffDim(TestConv3DOp_2):
def init_test_case(self):
self.stride = [1, 1, 2]
self.input_size = [2, 3, 4, 5, 5] # NCDHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 4, 3]
def init_paddings(self):
self.pad = [1, 0, 1, 0, 0, 2]
self.padding_algorithm = "EXPLICIT"
class TestCase1_AsyPadding(TestConv3DOp_2):
def init_test_case(self):
self.stride = [1, 1, 1]
self.input_size = [2, 3, 4, 4, 4] # NCDHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3, 3]
def init_paddings(self):
self.pad = [0, 0, 1, 0, 0, 2]
self.padding_algorithm = "EXPLICIT"
# --------- test python API ---------------
class TestConv3DAPI(unittest.TestCase):
def test_api(self):
input_NDHWC = fluid.layers.data(
name="input_NDHWC",
shape=[2, 5, 5, 5, 3],
append_batch_size=False,
dtype="float32")
input_NCDHW = fluid.layers.data(
name="input_NCDHW",
shape=[2, 3, 5, 5, 3],
append_batch_size=False,
dtype="float32")
fluid.layers.conv3d(
input=input_NDHWC,
num_filters=3,
filter_size=[3, 3, 3],
stride=[1, 1, 1],
padding=0,
dilation=[1, 1, 1],
groups=1,
data_format="NCDHW")
fluid.layers.conv3d(
input=input_NCDHW,
num_filters=3,
filter_size=[3, 3, 3],
stride=[1, 1, 1],
padding=[1, 2, 1, 0, 1, 0],
dilation=[1, 1, 1],
groups=1,
data_format="NCDHW")
fluid.layers.conv3d(
input=input_NCDHW,
num_filters=3,
filter_size=[3, 3, 3],
stride=[1, 1, 1],
padding=[[0, 0], [0, 0], [1, 1], [1, 1], [1, 1]],
dilation=[1, 1, 1],
groups=1,
data_format="NCDHW")
fluid.layers.conv3d(
input=input_NDHWC,
num_filters=3,
filter_size=[3, 3, 3],
stride=[1, 1, 1],
padding=[[0, 0], [1, 1], [1, 1], [1, 1], [0, 0]],
dilation=[1, 1, 1],
groups=1,
data_format="NDHWC")
fluid.layers.conv3d(
input=input_NCDHW,
num_filters=3,
filter_size=[3, 3, 3],
stride=[1, 1, 1],
padding="SAME",
dilation=[1, 1, 1],
groups=1,
data_format="NCDHW")
fluid.layers.conv3d(
input=input_NCDHW,
num_filters=3,
filter_size=[3, 3, 3],
stride=[1, 1, 1],
padding="VALID",
dilation=[1, 1, 1],
groups=1,
data_format="NCDHW")
class TestConv3DAPI_Error(unittest.TestCase):
def test_api(self):
input = fluid.layers.data(
name="input",
shape=[2, 5, 5, 5, 4],
append_batch_size=False,
dtype="float32")
# ValueError: cudnn
def run_1():
fluid.layers.conv3d(
input=input,
num_filters=3,
filter_size=3,
stride=1,
padding=0,
dilation=1,
groups=1,
use_cudnn=[0],
data_format="NCDHW")
self.assertRaises(ValueError, run_1)
# ValueError: data_format
def run_2():
fluid.layers.conv3d(
input=input,
num_filters=3,
filter_size=[3, 3, 3],
stride=[1, 1, 1],
padding=0,
dilation=[1, 1, 1],
groups=1,
use_cudnn=False,
data_format="NCHWC")
self.assertRaises(ValueError, run_2)
# ValueError: padding
def run_3():
fluid.layers.conv3d(
input=input,
num_filters=3,
filter_size=3,
stride=1,
padding="SAMEE",
dilation=1,
groups=1,
use_cudnn=False,
data_format="NCDHW")
self.assertRaises(ValueError, run_3)
def run_4():
fluid.layers.conv3d(
input=input,
num_filters=3,
filter_size=3,
stride=1,
padding=[[0, 1], [0, 0], [0, 1], [0, 1], [0, 1]],
dilation=1,
groups=1,
use_cudnn=False,
data_format="NCDHW")
self.assertRaises(ValueError, run_4)
def run_5():
fluid.layers.conv3d(
input=input,
num_filters=3,
filter_size=0,
stride=0,
padding=[[0, 1], [0, 1], [0, 1], [0, 1], [0, 1]],
dilation=1,
groups=1,
use_cudnn=False,
data_format="NDHWC")
self.assertRaises(ValueError, run_5)
# ValueError: channel dimmention
x = fluid.layers.data(
name="x",
shape=[2, 5, 5, 5, -1],
append_batch_size=False,
dtype="float32")
def run_6():
fluid.layers.conv3d(
input=x,
num_filters=3,
filter_size=3,
stride=1,
padding=0,
dilation=1,
groups=1,
use_cudnn=False,
data_format="NDHWC")
self.assertRaises(ValueError, run_6)
# ValueError: groups
def run_7():
fluid.layers.conv3d(
input=input,
num_filters=3,
filter_size=3,
stride=1,
padding=0,
dilation=1,
groups=3,
use_cudnn=False,
data_format="NDHWC")
self.assertRaises(ValueError, run_7)
# ValueError: filter num
def run_8():
fluid.layers.conv3d(
input=input,
num_filters=0,
filter_size=0,
stride=0,
padding=0,
dilation=0,
groups=1,
use_cudnn=False,
data_format="NDHWC")
self.assertRaises(ValueError, run_8)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册