Unverified commit d3e06a51, authored by Qi Li, committed by GitHub

[NPU] fix depthwise_conv2d_grad, test=develop (#35626)

* [NPU] fix depthwise_conv2d_grad, test=develop

* remove debug files, test=develop
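
For reference, a minimal usage sketch (hypothetical shapes, mirroring the unit test touched in this change) of a conv2d call that takes the NPU depthwise path; that path requires input_channel == output_channel == groups, and the backward ops additionally require dilation == 1:

    # Hypothetical sketch, not part of this commit: on an NPU build of Paddle,
    # conv2d with in_channels == out_channels == groups dispatches to the
    # depthwise_conv2d kernel fixed here; its grad ops require dilation == 1.
    import numpy as np
    import paddle

    x = paddle.to_tensor(np.random.random((2, 4, 16, 16)).astype('float32'))  # NCHW input
    w = paddle.to_tensor(np.random.random((4, 1, 3, 3)).astype('float32'))    # (out_c, in_c/groups, kh, kw)
    y = paddle.nn.functional.conv2d(x, w, bias=None, groups=4, dilation=1)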
Parent 2367cca6
@@ -19,33 +19,26 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using NPUDeviceContext = platform::NPUDeviceContext;
template <typename DeviceContext, typename T>
template <typename T>
class DepthwiseConvNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
// input
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* filter = context.Input<Tensor>("Filter");
// output
Tensor* output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace());
// attr
const std::vector<int> stride = context.Attr<std::vector<int>>("strides");
std::vector<int> padding = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilation = context.Attr<std::vector<int>>("dilations");
const std::string data_format = context.Attr<std::string>("data_format");
const std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* input = ctx.Input<Tensor>("Input");
const Tensor* filter = ctx.Input<Tensor>("Filter");
Tensor* output = ctx.Output<Tensor>("Output");
output->mutable_data<T>(ctx.GetPlace());
// npu stream
auto stream =
context.template device_context<platform::NPUDeviceContext>().stream();
const std::vector<int> stride = ctx.Attr<std::vector<int>>("strides");
std::vector<int> padding = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilation = ctx.Attr<std::vector<int>>("dilations");
const std::string data_format = ctx.Attr<std::string>("data_format");
const std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm");
// check dimension
const bool channel_last = data_format == "NHWC";
if (channel_last) {
// NHWC
PADDLE_ENFORCE_EQ(
output->dims()[output->dims().size() - 1],
input->dims()[input->dims().size() - 1],
@@ -56,7 +49,6 @@ class DepthwiseConvNPUKernel : public framework::OpKernel<T> {
output->dims()[output->dims().size() - 1],
input->dims()[input->dims().size() - 1]));
} else {
// NCHW
PADDLE_ENFORCE_EQ(
output->dims()[1], input->dims()[1],
platform::errors::InvalidArgument(
@@ -66,7 +58,6 @@ class DepthwiseConvNPUKernel : public framework::OpKernel<T> {
output->dims()[1], input->dims()[1]));
}
// update padding and dilation
auto in_dims = input->dims();
auto filter_dims = filter->dims();
framework::DDim in_data_dims;
@@ -83,17 +74,6 @@ class DepthwiseConvNPUKernel : public framework::OpKernel<T> {
UpdatePaddingAndDilation(&padding, &dilation, padding_algorithm,
in_data_dims, stride, ksize);
// Transform filter (n, 1, h, w) --> (1, n, h, w)
Tensor transformed_filter(filter->type());
transformed_filter.mutable_data<T>({filter->dims()[1], filter->dims()[0],
filter->dims()[2], filter->dims()[3]},
context.device_context().GetPlace());
std::vector<int> perm = {1, 0, 2, 3};
const auto& runner_trans = NpuOpRunner(
"TransposeD", {*filter}, {transformed_filter}, {{"perm", perm}});
runner_trans.Run(stream);
// construct NPU attr
std::vector<int> strides(4, 1);
std::vector<int> dilations(4, 1);
@@ -115,7 +95,18 @@ class DepthwiseConvNPUKernel : public framework::OpKernel<T> {
dilations[3] = dilation[1];
}
// CANN OP
auto stream = ctx.template device_context<NPUDeviceContext>().stream();
// Transform filter (n, 1, h, w) --> (1, n, h, w)
Tensor transformed_filter(filter->type());
transformed_filter.mutable_data<T>({filter->dims()[1], filter->dims()[0],
filter->dims()[2], filter->dims()[3]},
ctx.device_context().GetPlace());
std::vector<int> perm = {1, 0, 2, 3};
const auto& runner_trans = NpuOpRunner(
"TransposeD", {*filter}, {transformed_filter}, {{"perm", perm}});
runner_trans.Run(stream);
const auto& runner =
NpuOpRunner("DepthwiseConv2D", {input_tensor, transformed_filter},
{output_tensor}, {{"strides", strides},
@@ -129,27 +120,20 @@ class DepthwiseConvNPUKernel : public framework::OpKernel<T> {
template <typename T>
class DepthwiseConvGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
// input
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* filter = context.Input<Tensor>("Filter");
// output
auto output_grad = context.Input<Tensor>(framework::GradVarName("Output"));
auto input_grad = context.Output<Tensor>(framework::GradVarName("Input"));
auto filter_grad = context.Output<Tensor>(framework::GradVarName("Filter"));
// attr
const std::vector<int> stride = context.Attr<std::vector<int>>("strides");
std::vector<int> padding = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilation = context.Attr<std::vector<int>>("dilations");
const std::string data_format = context.Attr<std::string>("data_format");
const std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* input = ctx.Input<Tensor>("Input");
const Tensor* filter = ctx.Input<Tensor>("Filter");
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
// npu stream
auto stream =
context.template device_context<platform::NPUDeviceContext>().stream();
const std::vector<int> stride = ctx.Attr<std::vector<int>>("strides");
std::vector<int> padding = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilation = ctx.Attr<std::vector<int>>("dilations");
const std::string data_format = ctx.Attr<std::string>("data_format");
const std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm");
// check dimension
const bool channel_last = data_format == "NHWC";
// update padding and dilation
@@ -169,11 +153,13 @@ class DepthwiseConvGradNPUKernel : public framework::OpKernel<T> {
UpdatePaddingAndDilation(&padding, &dilation, padding_algorithm,
in_data_dims, stride, ksize);
auto stream = ctx.template device_context<NPUDeviceContext>().stream();
// Transform filter (n, 1, h, w) --> (1, n, h, w)
Tensor transformed_filter(filter->type());
transformed_filter.mutable_data<T>({filter->dims()[1], filter->dims()[0],
filter->dims()[2], filter->dims()[3]},
context.device_context().GetPlace());
ctx.device_context().GetPlace());
std::vector<int> perm = {1, 0, 2, 3};
const auto& runner_trans = NpuOpRunner(
"TransposeD", {*filter}, {transformed_filter}, {{"perm", perm}});
@@ -200,39 +186,52 @@ class DepthwiseConvGradNPUKernel : public framework::OpKernel<T> {
dilations[3] = dilation[1];
}
// LOG(INFO) << "strides = " << framework::make_ddim(strides).to_str();
// LOG(INFO) << "dilations = " << framework::make_ddim(dilations).to_str();
// LOG(INFO) << "padding = " << framework::make_ddim(padding).to_str();
// LOG(INFO) << "data_format = " << data_format;
if (filter_grad) {
filter_grad->mutable_data<T>(context.GetPlace());
std::vector<int> filter_shape_vec =
framework::vectorize<int>(transformed_filter.dims());
filter_grad->mutable_data<T>(ctx.GetPlace());
const auto& runner = NpuOpRunner(
"DepthwiseConv2DBackpropFilterD", {input_tensor, output_grad_tensor},
{*filter_grad}, {{"filter_size", filter_shape_vec},
{"strides", strides},
{"pads", padding},
{"dilations", dilations},
{"data_format", data_format}});
runner.Run(stream);
PADDLE_ENFORCE_EQ(
(dilations[2] == 1 && dilations[3] == 1), true,
platform::errors::InvalidArgument(
"dilation_h and dilation_w in DepthwiseConv2DBackpropFilterD "
"must be equal to 1, but got dilation_h %d, dilation_w %d",
dilations[2], dilations[3]));
NpuOpRunner runner;
runner.SetType("DepthwiseConv2DBackpropFilterD")
.AddInput(input_tensor)
.AddInput(output_grad_tensor)
.AddOutput(*filter_grad)
.AddAttr("filter_size",
framework::vectorize(transformed_filter.dims()))
.AddAttr("strides", strides)
.AddAttr("dilations", dilations)
.AddAttr("pads", padding)
.AddAttr("data_format", data_format)
.Run(stream);
}
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
std::vector<int> input_shape_vec =
framework::vectorize<int>(input->dims());
input_grad->mutable_data<T>(ctx.GetPlace());
Tensor input_grad_tensor;
input_grad_tensor.ShareDataWith(*input_grad);
if (channel_last) {
input_grad_tensor.set_layout(DataLayout::kNHWC);
}
const auto& runner =
NpuOpRunner("DepthwiseConv2DBackpropInputD",
{transformed_filter, output_grad_tensor},
{input_grad_tensor}, {{"input_size", input_shape_vec},
{"strides", strides},
{"pads", padding},
{"dilations", dilations},
{"data_format", data_format}});
runner.Run(stream);
NpuOpRunner runner;
runner.SetType("DepthwiseConv2DBackpropInputD")
.AddInput(transformed_filter)
.AddInput(output_grad_tensor)
.AddOutput(input_grad_tensor)
.AddAttr("input_size", framework::vectorize(input->dims()))
.AddAttr("strides", strides)
.AddAttr("dilations", dilations)
.AddAttr("pads", padding)
.AddAttr("data_format", data_format)
.Run(stream);
}
}
};
@@ -241,7 +240,6 @@ template <typename T>
class NPUConvOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<platform::NPUDeviceContext>();
const Tensor* input = ctx.Input<Tensor>("Input");
auto* filter = ctx.Input<Tensor>("Filter");
auto* output = ctx.Output<Tensor>("Output");
@@ -293,6 +291,7 @@ class NPUConvOpKernel : public framework::OpKernel<T> {
dilations_vec[3] = dilations[1];
}
auto stream = ctx.template device_context<NPUDeviceContext>().stream();
const auto& runner =
NpuOpRunner("Conv2D", {input_tensor, *filter}, {output_tensor},
{{"strides", strides_vec},
@@ -300,7 +299,7 @@ class NPUConvOpKernel : public framework::OpKernel<T> {
{"dilations", dilations_vec},
{"groups", groups},
{"data_format", data_format}});
runner.Run(dev_ctx.stream());
runner.Run(stream);
}
};
@@ -308,8 +307,6 @@ template <typename T>
class NPUConvGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<platform::NPUDeviceContext>();
auto input = ctx.Input<Tensor>("Input");
auto filter = ctx.Input<Tensor>("Filter");
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
@@ -363,6 +360,7 @@ class NPUConvGradOpKernel : public framework::OpKernel<T> {
dilations_vec[3] = dilations[1];
}
auto stream = ctx.template device_context<NPUDeviceContext>().stream();
if (filter_grad) {
filter_grad->mutable_data<T>(ctx.GetPlace());
std::vector<int> filter_shape_vec =
@@ -376,7 +374,7 @@ class NPUConvGradOpKernel : public framework::OpKernel<T> {
{"dilations", dilations_vec},
{"groups", groups},
{"data_format", data_format}});
runner.Run(dev_ctx.stream());
runner.Run(stream);
}
if (input_grad) {
input_grad->mutable_data<T>(ctx.GetPlace());
@@ -396,7 +394,7 @@ class NPUConvGradOpKernel : public framework::OpKernel<T> {
{"dilations", dilations_vec},
{"groups", groups},
{"data_format", data_format}});
runner.Run(dev_ctx.stream());
runner.Run(stream);
}
}
};
@@ -404,15 +402,17 @@ class NPUConvGradOpKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(depthwise_conv2d, ops::DepthwiseConvNPUKernel<float>,
ops::DepthwiseConvNPUKernel<plat::float16>);
REGISTER_OP_NPU_KERNEL(depthwise_conv2d_grad,
ops::DepthwiseConvGradNPUKernel<float>,
ops::DepthwiseConvGradNPUKernel<plat::float16>);
REGISTER_OP_NPU_KERNEL(
depthwise_conv2d,
ops::DepthwiseConvNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(
depthwise_conv2d_grad,
ops::DepthwiseConvGradNPUKernel<paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(conv2d, ops::NPUConvOpKernel<float>,
ops::NPUConvOpKernel<paddle::platform::float16>);
ops::NPUConvOpKernel<plat::float16>);
REGISTER_OP_NPU_KERNEL(conv2d_grad, ops::NPUConvGradOpKernel<float>,
ops::NPUConvGradOpKernel<paddle::platform::float16>);
ops::NPUConvGradOpKernel<plat::float16>);
@@ -203,6 +203,14 @@ class Conv2D(layers.Layer):
else:
self._l_type = 'conv2d'
# NPU only supports depthwise_conv2d when "input_channel = output_channel = groups"
if core.is_compiled_with_npu():
if (self._num_channels == self._groups and
self._num_channels == self._num_filters):
self._l_type = 'depthwise_conv2d'
else:
self._l_type = 'conv2d'
self._num_channels = num_channels
if self._groups is None:
num_filter_channels = self._num_channels
@@ -1547,6 +1547,13 @@ def conv2d(input,
core.is_compiled_with_rocm()):
l_type = 'depthwise_conv2d'
# NPU only supports depthwise_conv2d when "input_channel = output_channel = groups"
if core.is_compiled_with_npu():
if (num_channels == groups and num_channels == num_filters):
l_type = 'depthwise_conv2d'
else:
l_type = 'conv2d'
helper = LayerHelper(l_type, **locals())
dtype = helper.input_dtype()
@@ -20,7 +20,7 @@ import paddle
import paddle.fluid as fluid
import sys
sys.path.append("..")
from op_test import OpTest, skip_check_grad_ci
from op_test import OpTest
from test_conv2d_op import conv2d_forward_naive
from paddle import ParamAttr
from paddle.regularizer import L2Decay
@@ -66,14 +66,22 @@ def create_test_padding_VALID_class(parent):
globals()[cls_name] = TestPaddingVALIDCase
@skip_check_grad_ci(
reason='''Inference only, it doesn't need to call check_grad.''')
def create_test_fp16_class(parent):
class TestFp16Case(parent):
def init_data_type(self):
self.dtype = np.float16
cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
TestFp16Case.__name__ = cls_name
globals()[cls_name] = TestFp16Case
class TestDepthwiseConvNPU(OpTest):
def setUp(self):
self.op_type = "depthwise_conv2d"
self.dtype = np.float16
self.set_npu()
self.op_type = "depthwise_conv2d"
self.init_data_format()
self.init_data_type()
self.init_test_case()
self.init_test_case_2()
@@ -114,18 +122,52 @@ class TestDepthwiseConvNPU(OpTest):
self.pad = [1, 1]
self.dilations = [1, 1]
self.stride = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
self.input_size = [2, 12, 5, 5] # NCHW
self.groups = 12
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [3, f_c, 3, 3]
self.filter_size = [12, f_c, 3, 3]
def test_check_output(self):
self.check_output_with_place(self.place)
self.check_output_with_place(self.place, atol=1e-2)
def test_check_grad(self):
if self.dtype == np.float16:
return
if self.dilations[0] == 1 and self.dilations[1] == 1:
self.check_grad_with_place(
self.place, {'Input', 'Filter'},
'Output',
max_relative_error=0.03,
numeric_place=paddle.CPUPlace())
def test_check_grad_no_filter(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
self.place, ['Input'],
'Output',
no_grad_set=set(['Filter']),
max_relative_error=0.03,
numeric_place=paddle.CPUPlace())
def test_check_grad_no_input(self):
if self.dtype == np.float16:
return
if self.dilations[0] == 1 and self.dilations[1] == 1:
self.check_grad_with_place(
self.place, ['Filter'],
'Output',
no_grad_set=set(['Input']),
max_relative_error=0.03,
numeric_place=paddle.CPUPlace())
def init_data_format(self):
self.data_format = "NCHW"
def init_data_type(self):
self.dtype = np.float32
def init_test_case_2(self):
pass
@@ -135,45 +177,44 @@ class TestDepthwiseConvNPU2(TestDepthwiseConvNPU):
self.pad = [1, 1]
self.dilations = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
self.input_size = [2, 12, 5, 5] # NCHW
self.groups = 12
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [3, f_c, 3, 3]
self.filter_size = [12, f_c, 3, 3]
class TestDepthwiseConvNPU3(TestDepthwiseConvNPU):
def init_test_case(self):
self.pad = [1, 1]
self.dilations = [2, 2]
self.dilations = [1, 1]
self.stride = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
self.input_size = [2, 12, 5, 5] # NCHW
self.groups = 12
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [3, f_c, 3, 3]
self.filter_size = [12, f_c, 3, 3]
class TestDepthwiseConvNPU4(TestDepthwiseConvNPU):
def init_test_case(self):
self.pad = [1, 1]
self.dilations = [2, 2]
self.dilations = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
self.input_size = [2, 12, 5, 5] # NCHW
self.groups = 12
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [3, f_c, 3, 3]
self.filter_size = [12, f_c, 3, 3]
@skip_check_grad_ci(
reason='''Inference only, it doesn't need to call check_grad.''')
class TestDepthwiseConvNPU_Padding(OpTest):
def setUp(self):
self.op_type = "depthwise_conv2d"
self.dtype = np.float16
self.dtype = np.float32
self.set_npu()
self.init_data_format()
self.init_data_type()
self.init_paddings()
self.init_test_case()
self.init_test_case_2()
@@ -215,18 +256,50 @@ class TestDepthwiseConvNPU_Padding(OpTest):
self.pad = [1, 1, 0, 1]
self.dilations = [1, 1]
self.stride = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
self.input_size = [2, 12, 5, 5] # NCHW
self.groups = 12
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [3, f_c, 3, 3]
self.filter_size = [12, f_c, 3, 3]
def test_check_output(self):
self.check_output_with_place(self.place)
self.check_output_with_place(self.place, atol=1e-2)
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
self.place, {'Input', 'Filter'},
'Output',
max_relative_error=0.03,
numeric_place=paddle.CPUPlace())
def test_check_grad_no_filter(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
self.place, ['Input'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Filter']),
numeric_place=paddle.CPUPlace())
def test_check_grad_no_input(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
self.place, ['Filter'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Input']),
numeric_place=paddle.CPUPlace())
def init_data_format(self):
self.data_format = "NCHW"
def init_data_type(self):
self.dtype = np.float32
def init_paddings(self):
self.pad = [1, 1, 0, 1]
self.padding_algorithm = "EXPLICIT"
@@ -240,11 +313,11 @@ class TestDepthwiseConvNPU2_Padding(TestDepthwiseConvNPU_Padding):
self.pad = [1, 1, 0, 1]
self.dilations = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
self.input_size = [2, 12, 5, 5] # NCHW
self.groups = 12
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [3, f_c, 3, 3]
self.filter_size = [12, f_c, 3, 3]
def init_paddings(self):
self.pad = [0, 1, 0, 2]
@@ -256,11 +329,11 @@ class TestDepthwiseConvNPU3_Padding(TestDepthwiseConvNPU_Padding):
self.pad = [1, 1, 0, 1]
self.dilations = [1, 1]
self.stride = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
self.input_size = [2, 12, 5, 5] # NCHW
self.groups = 12
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [3, f_c, 3, 3]
self.filter_size = [12, f_c, 3, 3]
def init_paddings(self):
self.pad = [2, 1, 2, 3]
@@ -283,151 +356,11 @@ create_test_padding_VALID_class(TestDepthwiseConvNPU_Padding)
create_test_padding_VALID_class(TestDepthwiseConvNPU2_Padding)
create_test_padding_VALID_class(TestDepthwiseConvNPU3_Padding)
class TestDepthwiseConvNet(unittest.TestCase):
def __init__(self, methodName='runTest'):
super().__init__(methodName=methodName)
def _test(self, run_npu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)
a_np = np.random.random(size=(2, 4, 16, 16)).astype('float16')
b_np = np.random.random(size=(4, 1, 3, 3)).astype('float16')
if not run_npu:
a_np = a_np.astype('float32')
b_np = b_np.astype('float32')
label_np = np.random.randint(10, size=(2, 10)).astype('float32')
with paddle.static.program_guard(main_prog, startup_prog):
if run_npu:
a = paddle.static.data(
name="a", shape=[2, 4, 16, 16], dtype='float16')
b = paddle.static.data(
name="b", shape=[4, 1, 3, 3], dtype='float16')
else:
a = paddle.static.data(
name="a", shape=[2, 4, 16, 16], dtype='float32')
b = paddle.static.data(
name="b", shape=[4, 1, 3, 3], dtype='float32')
label = paddle.static.data(
name="label", shape=[2, 10], dtype='float32')
a *= 2.0
b += 0.01
fc_1 = paddle.nn.functional.conv2d(a, b, bias=None, groups=4)
if run_npu:
fc_1 = paddle.cast(fc_1, dtype='float32')
fc_1 = paddle.nn.functional.relu(fc_1)
prediction = fluid.layers.fc(input=fc_1, size=10, act='softmax')
cost = paddle.nn.functional.smooth_l1_loss(
input=prediction, label=label)
loss = paddle.sum(cost)
sgd = fluid.optimizer.SGD(learning_rate=0.00001)
sgd.minimize(loss)
if run_npu:
place = paddle.NPUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(100):
pred_res, loss_res = exe.run(
main_prog,
feed={"a": a_np,
"b": b_np,
"label": label_np},
fetch_list=[prediction, loss])
return pred_res, loss_res
def test_npu(self):
cpu_pred, cpu_loss = self._test(False)
npu_pred, npu_loss = self._test(True)
self.assertTrue(np.allclose(npu_pred, cpu_pred, rtol=1e-04, atol=1e-03))
self.assertTrue(np.allclose(npu_loss, cpu_loss, rtol=1e-04, atol=1e-03))
class TestDepthwiseConvNet_NHWC(unittest.TestCase):
def __init__(self, methodName='runTest'):
super().__init__(methodName=methodName)
def _test(self, run_npu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)
a_np = np.random.random(size=(2, 16, 16, 4)).astype('float16')
b_np = np.random.random(size=(4, 1, 3, 3)).astype('float16')
if not run_npu:
a_np = a_np.astype('float32')
b_np = b_np.astype('float32')
label_np = np.random.randint(10, size=(2, 10)).astype('float32')
with paddle.static.program_guard(main_prog, startup_prog):
if run_npu:
a = paddle.static.data(
name="a", shape=[2, 16, 16, 4], dtype='float16')
b = paddle.static.data(
name="b", shape=[4, 1, 3, 3], dtype='float16')
else:
a = paddle.static.data(
name="a", shape=[2, 16, 16, 4], dtype='float32')
b = paddle.static.data(
name="b", shape=[4, 1, 3, 3], dtype='float32')
label = paddle.static.data(
name="label", shape=[2, 10], dtype='float32')
a *= 2.0
b += 0.01
fc_1 = paddle.nn.functional.conv2d(
a, b, bias=None, groups=4, data_format='NHWC')
if run_npu:
fc_1 = paddle.cast(fc_1, dtype='float32')
fc_1 = paddle.nn.functional.relu(fc_1)
prediction = fluid.layers.fc(input=fc_1, size=10, act='softmax')
cost = paddle.nn.functional.smooth_l1_loss(
input=prediction, label=label)
loss = paddle.sum(cost)
sgd = fluid.optimizer.SGD(learning_rate=0.00001)
sgd.minimize(loss)
if run_npu:
place = paddle.NPUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(100):
pred_res, loss_res = exe.run(
main_prog,
feed={"a": a_np,
"b": b_np,
"label": label_np},
fetch_list=[prediction, loss])
return pred_res, loss_res
def test_npu(self):
cpu_pred, cpu_loss = self._test(False)
npu_pred, npu_loss = self._test(True)
self.assertTrue(np.allclose(npu_pred, cpu_pred, rtol=1e-04, atol=1e-03))
self.assertTrue(np.allclose(npu_loss, cpu_loss, rtol=1e-04, atol=1e-03))
create_test_fp16_class(TestDepthwiseConvNPU)
create_test_fp16_class(TestDepthwiseConvNPU2)
create_test_fp16_class(TestDepthwiseConvNPU_Padding)
create_test_fp16_class(TestDepthwiseConvNPU2_Padding)
create_test_fp16_class(TestDepthwiseConvNPU3_Padding)
if __name__ == '__main__':
unittest.main()
@@ -127,24 +127,33 @@ class TestConv2DOp(OpTest):
self.check_output_with_place(fluid.NPUPlace(0), atol=1e-2)
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
fluid.NPUPlace(0), {'Input', 'Filter'},
'Output',
max_relative_error=0.03)
max_relative_error=0.03,
numeric_place=paddle.CPUPlace())
def test_check_grad_no_filter(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
fluid.NPUPlace(0), ['Input'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Filter']))
no_grad_set=set(['Filter']),
numeric_place=paddle.CPUPlace())
def test_check_grad_no_input(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
fluid.NPUPlace(0), ['Filter'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Input']))
no_grad_set=set(['Input']),
numeric_place=paddle.CPUPlace())
def init_test_case(self):
self.pad = [0, 0]
@@ -310,23 +319,32 @@ class TestConv2DOp_v2(OpTest):
self.check_output_with_place(paddle.NPUPlace(0), atol=1e-2)
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
paddle.NPUPlace(0), {'Input', 'Filter'},
'Output',
max_relative_error=0.02)
max_relative_error=0.02,
numeric_place=paddle.CPUPlace())
def test_check_grad_no_filter(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
paddle.NPUPlace(0), ['Input'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Filter']))
no_grad_set=set(['Filter']),
numeric_place=paddle.CPUPlace())
def test_check_grad_no_input(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
paddle.NPUPlace(0), ['Filter'],
'Output',
no_grad_set=set(['Input']))
no_grad_set=set(['Input']),
numeric_place=paddle.CPUPlace())
def init_test_case(self):
self.pad = [0, 0]
@@ -344,6 +344,13 @@ def conv1d(x,
l_type = 'depthwise_conv2d'
use_cudnn = False
# NPU only supports depthwise_conv2d when "input_channel = output_channel = groups"
if core.is_compiled_with_npu():
if (num_channels == groups and num_channels == num_filters):
l_type = 'depthwise_conv2d'
else:
l_type = 'conv2d'
squeeze_aixs = -2 if channel_last else -1
x = unsqueeze(x, axis=[squeeze_aixs])
weight = unsqueeze(weight, axis=[-1])
@@ -562,6 +569,13 @@ def conv2d(x,
else:
use_cudnn = False
# NPU only supports depthwise_conv2d when "input_channel = output_channel = groups"
if core.is_compiled_with_npu():
if (num_channels == groups and num_channels == num_filters):
l_type = 'depthwise_conv2d'
else:
l_type = 'conv2d'
if (core.is_compiled_with_cuda() and get_flags("FLAGS_conv2d_disable_cudnn")
["FLAGS_conv2d_disable_cudnn"]):
use_cudnn = False