提交 fe77223d 编写于 作者: D Dmitry Kurtaev

Modify nGraph's ConvolutionBackpropData and GroupConvolution

上级 a2642d83
......@@ -544,6 +544,12 @@ public:
const int group = inpCn / inpGroupCn;
std::vector<size_t> kernel_shape = getShape<size_t>(blobs[0]);
if (group != 1)
{
kernel_shape[0] /= group;
kernel_shape.insert(kernel_shape.begin(), group);
}
auto ieWeights = std::make_shared<ngraph::op::Constant>(ngraph::element::f32, kernel_shape, blobs[0].data);
if (fusedWeights)
{
......@@ -566,14 +572,12 @@ public:
std::shared_ptr<ngraph::Node> conv_node;
if (group != 1) {
conv_node = std::make_shared<ngraph::op::GroupConvolution>(
conv_node = std::make_shared<ngraph::op::v1::GroupConvolution>(
ieInpNode, ieWeights,
ngraph::Strides(strides),
ngraph::Strides(dilations),
ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(pads_begin.begin(), pads_begin.end())),
ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(pads_end.begin(), pads_end.end())),
ngraph::Strides{},
group,
ngraph::Strides(dilations),
pad_type);
} else {
conv_node = std::make_shared<ngraph::op::v1::Convolution>(
......@@ -2037,37 +2041,29 @@ public:
Mat newWeights = blobs[0].reshape(1, inpCn);
transpose(weightsMat, newWeights);
}
size_t batch = ieInpNode->get_shape()[0];
std::vector<size_t> out_shape = {batch, (size_t)numOutput};
std::vector<size_t> paddings_end;
std::vector<size_t> inpShape = ieInpNode->get_shape();
if (padMode.empty())
{
for (int i = 0; i < pads_end.size(); i++) {
out_shape.push_back(strides[i] * (inpShape[2 + i] - 1) +
kernel_size[i] - pads_begin[i] - pads_end[i] + adjust_pads[i]);
paddings_end.push_back(pads_end[i] - adjust_pads[i]);
}
}
else if (padMode == "SAME")
{
for (int i = 0; i < pads_begin.size(); i++) {
out_shape.push_back(strides[i] * (inpShape[2 + i] - 1) + 1 + adjust_pads[i]);
paddings_end.push_back(kernel_size[i] - pads_begin[i] - 1 - adjust_pads[i]);
}
} else {
paddings_end = pads_end;
}
auto deconv = std::make_shared<ngraph::op::ConvolutionBackpropData>(
ngraph::Shape{out_shape},
ieWeights,
auto deconv = std::make_shared<ngraph::op::v1::ConvolutionBackpropData>(
ieInpNode,
ieWeights,
ngraph::Strides(strides),
ngraph::Strides(dilations),
ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(pads_begin.begin(), pads_begin.end())),
ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(paddings_end.begin(), paddings_end.end())),
(strides.size() == 2 ? ngraph::Strides{1, 1} : ngraph::Strides{1, 1, 1}));
ngraph::Strides(dilations));
if (hasBias() || fusedBias)
{
std::vector<size_t> shape(deconv->get_shape().size(), 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册