Commit 50666a81 authored by 刘琦

Merge branch 'fix_winograd_variable_shape' into 'master'

fix winograd variable input shape

Replace the compile-time `batch`/`height`/`width` arguments of `WinogradInverseTransform` with a runtime output-shape tensor computed by a new `InferConv2dShape` op, so converted models no longer bake the input spatial dimensions into the graph.

See merge request !763
......@@ -14,7 +14,7 @@ python tools/converter.py run --config=/path/to/your/model_deployment_file --exa
* Validate result
```
python tools/converter.py run --config=/path/to/your/model_deployment_file --example --example
python tools/converter.py run --config=/path/to/your/model_deployment_file --example --validate
```
* Check the logs
......
......@@ -129,12 +129,14 @@ MaceStatus WinogradTransformFunctor<DeviceType::GPU, T>::operator()(
template <typename T>
MaceStatus WinogradInverseTransformFunctor<DeviceType::GPU, T>::operator()(
const Tensor *input_tensor,
const Tensor *bias,
const std::vector<const Tensor*> &inputs,
Tensor *output_tensor,
StatsFuture *future) {
auto runtime = OpenCLRuntime::Global();
const Tensor *input_tensor = inputs[0];
const Tensor *bias = inputs.size() == 3 ? inputs[2] : nullptr;
if (kernel_.get() == nullptr) {
std::string obfuscated_kernel_name;
std::set<std::string> built_options;
......@@ -191,18 +193,23 @@ MaceStatus WinogradInverseTransformFunctor<DeviceType::GPU, T>::operator()(
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
}
Tensor::MappingGuard output_shape_guard(inputs[1]);
const int32_t *output_shape_data = inputs[1]->data<int32_t>();
const index_t batch = output_shape_data[0];
const index_t height = output_shape_data[1];
const index_t width = output_shape_data[2];
const uint32_t gws[2] = {
static_cast<uint32_t>(input_tensor->dim(2)),
static_cast<uint32_t>(RoundUpDiv4(input_tensor->dim(1)))};
if (!IsVecEqual(input_shape_, input_tensor->shape())) {
std::vector<index_t> output_shape = {batch_, height_, width_,
std::vector<index_t> output_shape = {batch, height, width,
input_tensor->dim(1)};
std::vector<size_t> image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape);
MACE_RETURN_IF_ERROR(output_tensor->ResizeImage(output_shape, image_shape));
const index_t round_h = (height_ + wino_blk_size_ - 1) / wino_blk_size_;
const index_t round_w = (width_ + wino_blk_size_ - 1) / wino_blk_size_;
const index_t round_h = (height + wino_blk_size_ - 1) / wino_blk_size_;
const index_t round_w = (width + wino_blk_size_ - 1) / wino_blk_size_;
const float round_hw_r = 1.f / static_cast<float>(round_h * round_w);
const float round_w_r = 1.f / static_cast<float>(round_w);
......
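For readers unfamiliar with the tiling above: `round_h` and `round_w` are ceiling divisions counting how many `wino_blk_size`-sized Winograd tiles cover the output, and the reciprocals are precomputed on the host so the OpenCL kernel can map a flat tile index back to 2D coordinates without integer division. A minimal standalone sketch with made-up dimensions (not part of the change):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Illustrative values only: a 7x9 output with the m=2 block size
  // (F(2x2, 3x3)) that this code uses as a default elsewhere.
  const int64_t height = 7, width = 9, wino_blk_size = 2;

  // Ceiling division, as in the functor above: tiles per dimension.
  const int64_t round_h = (height + wino_blk_size - 1) / wino_blk_size;
  const int64_t round_w = (width + wino_blk_size - 1) / wino_blk_size;

  // Reciprocals handed to the kernel to avoid integer division there.
  const float round_hw_r = 1.f / static_cast<float>(round_h * round_w);
  const float round_w_r = 1.f / static_cast<float>(round_w);

  std::cout << round_h << " x " << round_w << " tiles, "  // 4 x 5 tiles
            << round_hw_r << ", " << round_w_r << "\n";
  return 0;
}
```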
......@@ -86,22 +86,13 @@ struct WinogradTransformFunctor<DeviceType::GPU, T>
#endif // MACE_ENABLE_OPENCL
struct WinogradInverseTransformFunctorBase {
WinogradInverseTransformFunctorBase(const int batch,
const int height,
const int width,
const ActivationType activation,
WinogradInverseTransformFunctorBase(const ActivationType activation,
const float relux_max_limit,
const int block_size)
: batch_(batch),
height_(height),
width_(width),
wino_blk_size_(block_size),
: wino_blk_size_(block_size),
activation_(activation),
relux_max_limit_(relux_max_limit) {}
const int batch_;
const int height_;
const int width_;
const int wino_blk_size_;
const ActivationType activation_;
const float relux_max_limit_;
......@@ -109,21 +100,16 @@ struct WinogradInverseTransformFunctorBase {
template<DeviceType D, typename T>
struct WinogradInverseTransformFunctor : WinogradInverseTransformFunctorBase {
WinogradInverseTransformFunctor(const int batch,
const int height,
const int width,
const ActivationType activation,
WinogradInverseTransformFunctor(const ActivationType activation,
const float relux_max_limit,
const int block_size)
: WinogradInverseTransformFunctorBase(
batch, height, width, activation, relux_max_limit, block_size) {}
activation, relux_max_limit, block_size) {}
MaceStatus operator()(const Tensor *input,
const Tensor *bias,
MaceStatus operator()(const std::vector<const Tensor*> &inputs,
Tensor *output,
StatsFuture *future) {
MACE_UNUSED(input);
MACE_UNUSED(bias);
MACE_UNUSED(inputs);
MACE_UNUSED(output);
MACE_UNUSED(future);
MACE_NOT_IMPLEMENTED;
......@@ -135,17 +121,13 @@ struct WinogradInverseTransformFunctor : WinogradInverseTransformFunctorBase {
template <typename T>
struct WinogradInverseTransformFunctor<DeviceType::GPU, T>
: WinogradInverseTransformFunctorBase {
WinogradInverseTransformFunctor(const int batch,
const int height,
const int width,
const ActivationType activation,
WinogradInverseTransformFunctor(const ActivationType activation,
const float relux_max_limit,
const int block_size)
: WinogradInverseTransformFunctorBase(
batch, height, width, activation, relux_max_limit, block_size) {}
activation, relux_max_limit, block_size) {}
MaceStatus operator()(const Tensor *input,
const Tensor *bias,
MaceStatus operator()(const std::vector<const Tensor*> &inputs,
Tensor *output,
StatsFuture *future);
......
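The signature change above replaces the fixed `(input, bias)` pair with a variable-length input vector. Per the GPU functor earlier in this diff, the convention is: `inputs[0]` is the Winograd-GEMM output, `inputs[1]` is an int32 `{batch, height, width, channels}` tensor produced at runtime by the new `InferConv2dShape` op, and `inputs[2]` is an optional bias, present only when `inputs.size() == 3`. A minimal sketch of that calling convention, using a hypothetical stand-in for `mace::Tensor`:

```cpp
#include <vector>

// Hypothetical stand-in for mace::Tensor, only to show input ordering.
struct Tensor {};

// inputs[0]: data to inverse-transform
// inputs[1]: runtime output-shape tensor from InferConv2dShape
// inputs[2]: optional bias (only when inputs.size() == 3)
void CallInverseTransform(const Tensor* gemm_out,
                          const Tensor* shape,
                          const Tensor* bias /* may be null */) {
  std::vector<const Tensor*> inputs{gemm_out, shape};
  if (bias != nullptr) inputs.push_back(bias);
  // functor(inputs, output, future);  // as declared above
}

int main() {
  Tensor gemm_out, shape, bias;
  CallInverseTransform(&gemm_out, &shape, &bias);
  return 0;
}
```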
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/infer_conv2d_shape.h"
namespace mace {
namespace ops {
void Register_InferConv2dShape(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("InferConv2dShape")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
InferConv2dShapeOp<DeviceType::CPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("InferConv2dShape")
.Device(DeviceType::CPU)
.TypeConstraint<int32_t>("T")
.Build(),
InferConv2dShapeOp<DeviceType::CPU, int32_t>);
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("InferConv2dShape")
.Device(DeviceType::GPU)
.TypeConstraint<float>("T")
.Build(),
InferConv2dShapeOp<DeviceType::GPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("InferConv2dShape")
.Device(DeviceType::GPU)
.TypeConstraint<half>("T")
.Build(),
InferConv2dShapeOp<DeviceType::GPU, half>);
#endif
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_INFER_CONV2D_SHAPE_H_
#define MACE_OPS_INFER_CONV2D_SHAPE_H_
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/conv_pool_2d_util.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class InferConv2dShapeOp : public Operator<D, T> {
public:
InferConv2dShapeOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
MACE_CHECK(input->dim_size() == 4);
output->Resize({input->dim_size()});
Tensor::MappingGuard output_guard(output);
int32_t *output_data = output->mutable_data<int32_t>();
const int32_t data_format =
OperatorBase::GetOptionalArg<int>("data_format", 0);
const bool isNCHW = data_format == 1;
const Padding padding_type =
static_cast<Padding>(OperatorBase::GetOptionalArg<int>(
"padding", static_cast<int>(SAME)));
const std::vector<int32_t> paddings =
OperatorBase::GetRepeatedArgs<int32_t>("padding_values");
const std::vector<int32_t> kernels =
OperatorBase::GetRepeatedArgs<int32_t>("kernels");
const std::vector<int32_t> strides =
OperatorBase::GetRepeatedArgs<int32_t>("strides", {1, 1});
const int32_t out_batch = static_cast<int32_t>(input->dim(0));
const int32_t out_channel = static_cast<int32_t>(kernels[0]);
int32_t in_h = 0, in_w = 0, in_c = 0;
if (isNCHW) { // NCHW
in_c = static_cast<int32_t>(input->dim(1));
in_h = static_cast<int32_t>(input->dim(2));
in_w = static_cast<int32_t>(input->dim(3));
} else {
in_h = static_cast<int32_t>(input->dim(1));
in_w = static_cast<int32_t>(input->dim(2));
in_c = static_cast<int32_t>(input->dim(3));
}
MACE_CHECK(in_c == kernels[1],
"different number of input channels between input and kernel");
int32_t out_h = 0, out_w = 0;
if (!paddings.empty()) {
out_h = (in_h - kernels[2] + paddings[0]) / strides[0] + 1;
out_w = (in_w - kernels[3] + paddings[1]) / strides[1] + 1;
} else {
switch (padding_type) {
case SAME:
out_h = (in_h + strides[0] - 1) / strides[0];
out_w = (in_w + strides[1] - 1) / strides[1];
break;
case VALID:
out_h = (in_h - kernels[2] + 1) / strides[0];
out_w = (in_w - kernels[3] + 1) / strides[1];
break;
default:
MACE_NOT_IMPLEMENTED;
break;
}
}
if (isNCHW) {
output_data[0] = out_batch;
output_data[1] = out_channel;
output_data[2] = out_h;
output_data[3] = out_w;
} else {
output_data[0] = out_batch;
output_data[1] = out_h;
output_data[2] = out_w;
output_data[3] = out_channel;
}
SetFutureDefaultWaitFn(future);
return MACE_SUCCESS;
}
private:
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_INFER_CONV2D_SHAPE_H_
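As a sanity check on the padding arithmetic in `InferConv2dShapeOp::Run`, here is a standalone sketch (not part of the change) reproducing the SAME/VALID output-size formulas with the shapes exercised by the unit test that follows:

```cpp
#include <cassert>
#include <cstdint>

// Same arithmetic as InferConv2dShapeOp::Run above, extracted for
// illustration.
int32_t SameOut(int32_t in, int32_t stride) {
  return (in + stride - 1) / stride;  // ceil(in / stride)
}

int32_t ValidOut(int32_t in, int32_t kernel, int32_t stride) {
  return (in - kernel + 1) / stride;
}

int main() {
  // Matches the test: {3, 640, 480, 16}, stride 1 -> {3, 640, 480, 3}
  assert(SameOut(640, 1) == 640 && SameOut(480, 1) == 480);
  // Matches the test: {3, 640, 480, 16}, stride 2 -> {3, 320, 240, 3}
  assert(SameOut(640, 2) == 320 && SameOut(480, 2) == 240);
  // VALID padding with a 3x3 kernel trims (kernel - 1) border pixels.
  assert(ValidOut(640, 3, 1) == 638);
  return 0;
}
```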
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
#include "mace/kernels/conv_pool_2d_util.h"
namespace mace {
namespace ops {
namespace test {
class InferConv2dShapeOpTest : public OpsTestBase {};
namespace {
void TestInferConv2dShapeOp(const std::vector<index_t> &input_shape,
const int stride,
const std::vector<index_t> &output_shape) {
OpsTestNet net;
net.AddRandomInput<CPU, float>("Input", input_shape);
const int in_ch = static_cast<int>(input_shape[3]);
const int out_ch = static_cast<int>(output_shape[3]);
OpDefBuilder("InferConv2dShape", "InferConv2dShapeOpTest")
.Input("Input")
.Output("Output")
.AddIntArg("datd_format", 0)
.AddIntsArg("strides", {stride, stride})
.AddIntsArg("kernels", {out_ch, in_ch, 3, 3})
.AddIntArg("padding", Padding::SAME)
.OutputType({DataTypeToEnum<int32_t>::v()})
.Finalize(net.NewOperatorDef());
net.RunOp();
std::vector<int32_t> expected_output_shape(output_shape.begin(),
output_shape.end());
net.AddInputFromArray<CPU, int32_t>("ExpectedOutput",
{static_cast<int32_t>(
output_shape.size())},
expected_output_shape);
ExpectTensorNear<int32_t>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(InferConv2dShapeOpTest, TestInferConv2dShape) {
TestInferConv2dShapeOp({3, 640, 480, 16}, 1, {3, 640, 480, 3});
TestInferConv2dShapeOp({3, 640, 480, 16}, 2, {3, 320, 240, 3});
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -39,6 +39,7 @@ extern void Register_FoldedBatchNorm(OperatorRegistryBase *op_registry);
extern void Register_FullyConnected(OperatorRegistryBase *op_registry);
extern void Register_Gather(OperatorRegistryBase *op_registry);
extern void Register_Identity(OperatorRegistryBase *op_registry);
extern void Register_InferConv2dShape(OperatorRegistryBase *op_registry);
extern void Register_LocalResponseNorm(OperatorRegistryBase *op_registry);
extern void Register_MatMul(OperatorRegistryBase *op_registry);
extern void Register_Pad(OperatorRegistryBase *op_registry);
......@@ -91,6 +92,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
ops::Register_FullyConnected(this);
ops::Register_Gather(this);
ops::Register_Identity(this);
ops::Register_InferConv2dShape(this);
ops::Register_LocalResponseNorm(this);
ops::Register_MatMul(this);
ops::Register_Pad(this);
......
......@@ -43,7 +43,21 @@ void BMWinogradConvolution(
// transform filter
BufferToImage<D, T>(&net, "Filter", "WinoFilter",
kernels::BufferType::WINOGRAD_FILTER, block_size);
// transform input
// Infer the convolution output shape
OpDefBuilder("InferConv2dShape", "InferConv2dShapeTest")
.Input("InputImage")
.Output("ShapeOutput")
.AddIntArg("data_format", 0)
.AddIntsArg("strides", {1, 1})
.AddIntsArg("kernels", {static_cast<int>(out_channels),
static_cast<int>(in_channels),
3, 3})
.AddIntArg("padding", Padding::SAME)
.OutputType({DataTypeToEnum<int32_t>::v()})
.Finalize(net.NewOperatorDef());
// Transform input
OpDefBuilder("WinogradTransform", "WinogradTransformTest")
.Input("InputImage")
.Output("WinoInput")
......@@ -63,6 +77,7 @@ void BMWinogradConvolution(
// Inverse transform
OpDefBuilder("WinogradInverseTransform", "WinogradInverseTransformTest")
.Input("WinoGemm")
.Input("ShapeOutput")
.Input("BiasImage")
.AddIntArg("batch", batch)
.AddIntArg("height", height)
......
......@@ -84,6 +84,19 @@ void WinogradConvolution(const index_t batch,
// Run on opencl
net.RunOp(D);
OpDefBuilder("InferConv2dShape", "InferConv2dShapeTest")
.Input("InputImage")
.Output("ShapeOutput")
.AddIntArg("data_format", 0)
.AddIntsArg("strides", {1, 1})
.AddIntsArg("kernels", {static_cast<int>(out_channels),
static_cast<int>(in_channels),
3, 3})
.AddIntArg("padding", padding)
.OutputType({DataTypeToEnum<int32_t>::v()})
.Finalize(net.NewOperatorDef());
net.RunOp(D);
// MatMul
OpDefBuilder("MatMul", "MatMulTest")
.Input("WinoFilter")
......@@ -97,10 +110,8 @@ void WinogradConvolution(const index_t batch,
// Inverse transform
OpDefBuilder("WinogradInverseTransform", "WinogradInverseTransformTest")
.Input("WinoGemm")
.Input("ShapeOutput")
.Input("BiasImage")
.AddIntArg("batch", batch)
.AddIntArg("height", output_shape[1])
.AddIntArg("width", output_shape[2])
.AddIntArg("wino_block_size", block_size)
.Output("WinoOutputImage")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
......@@ -221,6 +232,19 @@ void WinogradConvolutionWithPad(const index_t batch,
// Run on opencl
net.RunOp(D);
OpDefBuilder("InferConv2dShape", "InferConv2dShapeTest")
.Input("InputImage")
.Output("ShapeOutput")
.AddIntArg("data_format", 0)
.AddIntsArg("strides", {1, 1})
.AddIntsArg("kernels", {static_cast<int>(out_channels),
static_cast<int>(in_channels),
3, 3})
.AddIntsArg("padding_values", {padding, padding})
.OutputType({DataTypeToEnum<int32_t>::v()})
.Finalize(net.NewOperatorDef());
net.RunOp(D);
// MatMul
OpDefBuilder("MatMul", "MatMulTest")
.Input("WinoFilter")
......@@ -234,10 +258,8 @@ void WinogradConvolutionWithPad(const index_t batch,
// Inverse transform
OpDefBuilder("WinogradInverseTransform", "WinogradInverseTransformTest")
.Input("WinoGemm")
.Input("ShapeOutput")
.Input("BiasImage")
.AddIntArg("batch", batch)
.AddIntArg("height", output_shape[1])
.AddIntArg("width", output_shape[2])
.AddIntArg("wino_block_size", block_size)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Output("WinoOutputImage")
......
......@@ -17,6 +17,7 @@
#include <memory>
#include <string>
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/activation.h"
......@@ -30,27 +31,22 @@ class WinogradInverseTransformOp : public Operator<D, T> {
public:
WinogradInverseTransformOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetOptionalArg<int>("batch", 1),
OperatorBase::GetOptionalArg<int>("height", 0),
OperatorBase::GetOptionalArg<int>("width", 0),
kernels::StringToActivationType(
functor_(kernels::StringToActivationType(
OperatorBase::GetOptionalArg<std::string>("activation",
"NOOP")),
OperatorBase::GetOptionalArg<float>("max_limit", 0.0f),
OperatorBase::GetOptionalArg<int>("wino_block_size", 2)) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(INPUT);
const Tensor *bias = this->InputSize() == 2 ? this->Input(BIAS) : nullptr;
const std::vector<const Tensor *> &inputs = this->Inputs();
Tensor *output_tensor = this->Output(OUTPUT);
return functor_(input_tensor, bias, output_tensor, future);
return functor_(inputs, output_tensor, future);
}
private:
kernels::WinogradInverseTransformFunctor<D, T> functor_;
protected:
MACE_OP_INPUT_TAGS(INPUT, BIAS);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
......
......@@ -92,6 +92,7 @@ MaceSupportedOps = [
'FullyConnected',
'Gather',
'Identity',
'InferConv2dShape',
'LocalResponseNorm',
'MatMul',
'Pad',
......
......@@ -560,11 +560,47 @@ class Transformer(base_converter.ConverterInterface):
arg.name = MaceKeyword.mace_winograd_filter_transformed
arg.i = 1
shape_op = net.op.add()
shape_op.name = op.name + '_infer_shape'
shape_op.type = MaceOp.InferConv2dShape.name
shape_op.input.extend([op.input[0]])
shape_op.output.extend([shape_op.name])
shape_output_shape = shape_op.output_shape.add()
shape_output_shape.dims.extend([4])
kernels_arg = shape_op.arg.add()
kernels_arg.name = MaceKeyword.mace_kernel_str
kernels_arg.ints.extend([out_channels,
in_channels,
filter_height,
filter_width])
if data_format is not None:
data_format_arg = shape_op.arg.add()
data_format_arg.name = MaceKeyword.mace_data_format_str
data_format_arg.i = data_format.value
if ConverterUtil.get_arg(op,
MaceKeyword.mace_padding_str) \
is not None:
padding_arg = shape_op.arg.add()
padding_arg.name = MaceKeyword.mace_padding_str
padding_arg.i = ConverterUtil.get_arg(
op, MaceKeyword.mace_padding_str).i
elif ConverterUtil.get_arg(
op, MaceKeyword.mace_padding_values_str) \
is not None:
padding_arg = shape_op.arg.add()
padding_arg.name = MaceKeyword.mace_padding_values_str
padding_arg.ints.extend(ConverterUtil.get_arg(
op, MaceKeyword.mace_padding_values_str).ints)
# Inverse transform
iwt_op = net.op.add()
iwt_op.name = op.name + '_inverse_transform'
iwt_op.type = MaceOp.WinogradInverseTransform.name
iwt_op.input.extend([matmul_op.output[0]])
iwt_op.input.extend([shape_op.output[0]])
# biasadd
if len(op.input) >= 3:
iwt_op.input.extend([op.input[2]])
......@@ -572,15 +608,6 @@ class Transformer(base_converter.ConverterInterface):
iwt_output_shape = iwt_op.output_shape.add()
iwt_output_shape.dims.extend(op.output_shape[0].dims)
batch_arg = iwt_op.arg.add()
batch_arg.name = 'batch'
batch_arg.i = batch
height_arg = iwt_op.arg.add()
height_arg.name = 'height'
height_arg.i = out_height
width_arg = iwt_op.arg.add()
width_arg.name = 'width'
width_arg.i = out_width
blk_size_arg = iwt_op.arg.add()
blk_size_arg.name = MaceKeyword.mace_wino_block_size
blk_size_arg.i = block_size
......@@ -1146,7 +1173,7 @@ class Transformer(base_converter.ConverterInterface):
MaceKeyword.mace_winograd_filter_transformed) is not None: # noqa
self.buffer_to_image(op, 0, OpenCLBufferType.WINOGRAD_FILTER)
elif op.type == MaceOp.WinogradInverseTransform.name \
and len(op.input) >= 2:
and len(op.input) >= 3:
self.buffer_to_image(op, 1, OpenCLBufferType.ARGUMENT)
elif op.type == MaceOp.FullyConnected.name:
self.buffer_to_image(op, 1, OpenCLBufferType.WEIGHT_WIDTH)
......
......@@ -228,7 +228,11 @@ class GPUMemoryOptimizer(MemoryOptimizer):
mace_pb2.GPU_IMAGE,
calculate_image_shape(OpenCLBufferType.IN_OUT_HEIGHT,
buffer_shape))
elif op_type in ['Shape', 'StridedSlice', 'Stack', 'ScalarMath']:
elif op_type in ['Shape',
'InferConv2dShape',
'StridedSlice',
'Stack',
'ScalarMath']:
if len(output_shape) == 1:
mem_block = MemoryBlock(mace_pb2.CPU_BUFFER,
[output_shape[0], 1])
......