Commit 61715ab7, authored by liuqi

Fix resize_bilinear bug: out dims only support int32, so read them as int32.

Parent d4b8a5e0
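The core of the fix: the OutSize tensor that ResizeBilinear reads is stored as int32 (the tests below switch AddInputFromArray from index_t to int), so the functor must read it with data<int32_t>() instead of data<index_t>(). The sketch below is illustrative only, not the project's Tensor API: it uses a plain buffer and assumes index_t is a 64-bit type on a little-endian machine, which is why the old read garbles the dims.

// Hedged sketch (not the actual Tensor class): reading int32 out dims
// through a 64-bit index_t pointer misinterprets the raw bytes.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

using index_t = int64_t;  // assumption: index_t is a 64-bit integer

int main() {
  // OutSize contents as the converter/tests now provide them: int32 {1, 2}.
  std::vector<int32_t> out_dims = {1, 2};

  // Buggy read: treat the same 8 bytes as a single index_t value.
  index_t misread = 0;
  std::memcpy(&misread, out_dims.data(), sizeof(misread));
  std::cout << "read as index_t: " << misread << "\n";  // 8589934593 on little-endian, not 1 or 2

  // Fixed read, mirroring resize_dims->data<int32_t>() in the patch.
  std::cout << "read as int32_t: " << out_dims[0] << " x " << out_dims[1] << "\n";
  return 0;
}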
......@@ -55,7 +55,9 @@ Tensor *Workspace::GetTensor(const string &name) {
void Workspace::LoadModelTensor(const NetDef &net_def, DeviceType type) {
Serializer serializer;
for (auto &tensor_proto : net_def.tensors()) {
VLOG(1) << "Load tensor: " << tensor_proto.name() << " has shape: "
VLOG(1) << "Load tensor: " << tensor_proto.name()
<< ", with data type: " << tensor_proto.data_type()
<< ", has shape: "
<< internal::MakeString(vector<index_t>(tensor_proto.dims().begin(),
tensor_proto.dims().end()));
tensor_map_[tensor_proto.name()] =
......
......@@ -81,30 +81,36 @@ int main(int argc, char **argv) {
net_def.ParseFromIstream(&file_stream);
file_stream.close();
DeviceType device_type;
DeviceType_Parse(device, &device_type);
VLOG(0) << device_type;
Workspace ws;
ws.LoadModelTensor(net_def, DeviceType::CPU);
ws.LoadModelTensor(net_def, device_type);
Tensor *input_tensor =
ws.CreateTensor(input_node + ":0", GetDeviceAllocator(DeviceType::CPU), DT_FLOAT);
ws.CreateTensor(input_node + ":0", GetDeviceAllocator(device_type), DT_FLOAT);
input_tensor->Resize(shape);
float *input_data = input_tensor->mutable_data<float>();
{
Tensor::MappingGuard input_guard(input_tensor);
float *input_data = input_tensor->mutable_data<float>();
// load input
ifstream in_file(input_file, ios::in | ios::binary);
in_file.read(reinterpret_cast<char *>(input_data),
input_tensor->size() * sizeof(float));
in_file.close();
}
// load input
ifstream in_file(input_file, ios::in | ios::binary);
in_file.read(reinterpret_cast<char *>(input_data),
input_tensor->size() * sizeof(float));
in_file.close();
// run model
DeviceType device_type;
DeviceType_Parse(device, &device_type);
VLOG(0) << device_type;
auto net = CreateNet(net_def, &ws, device_type);
VLOG(0) << "warm up";
// warm up
for (int i = 0; i < 2; ++i) {
for (int i = 0; i < 1; ++i) {
net->Run();
}
VLOG(0) << "run";
timeval tv1, tv2;
gettimeofday(&tv1, NULL);
for (int i = 0; i < round; ++i) {
......@@ -120,9 +126,15 @@ int main(int argc, char **argv) {
// save output
const Tensor *output = ws.GetTensor(output_node + ":0");
Tensor::MappingGuard output_guard(output);
ofstream out_file(output_file, ios::binary);
out_file.write((const char *)(output->data<float>()),
output->size() * sizeof(float));
out_file.flush();
out_file.close();
VLOG(0) << "Output shape: ["
<< output->dim(0) << ", "
<< output->dim(1) << ", "
<< output->dim(2) << ", "
<< output->dim(3) << "]";
}
\ No newline at end of file
......@@ -17,7 +17,7 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
"Invalid dilations, must >= 1");
MACE_CHECK((dilations[0] == 1 || strides[0] == 1) &&
(dilations[1] == 1 || strides[1] == 1),
(dilations[1] == 1 || strides[1] == 1),
"If dilations > 1, strides should be 1");
MACE_CHECK_NOTNULL(output_shape);
MACE_CHECK_NOTNULL(padding_size);
......@@ -39,20 +39,16 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
index_t k_extent_width = (kernel_width - 1) * dilations[1] + 1;
switch (padding) {
case VALID:
output_height = (input_shape[2] - k_extent_height) / strides[0] + 1;
case VALID:output_height = (input_shape[2] - k_extent_height) / strides[0] + 1;
output_width = (input_shape[3] - k_extent_width) / strides[1] + 1;
break;
case SAME:
output_height = (input_shape[2] - 1) / strides[0] + 1;
case SAME:output_height = (input_shape[2] - 1) / strides[0] + 1;
output_width = (input_shape[3] - 1) / strides[1] + 1;
break;
case FULL:
output_height = (input_shape[2] + k_extent_height - 2) / strides[0] + 1;
case FULL:output_height = (input_shape[2] + k_extent_height - 2) / strides[0] + 1;
output_width = (input_shape[3] + k_extent_width - 2) / strides[1] + 1;
break;
default:
MACE_CHECK(false, "Unsupported padding type: ", padding);
default:MACE_CHECK(false, "Unsupported padding type: ", padding);
}
// Note: TensorFlow may pad one more on the right/bottom side
......@@ -61,10 +57,10 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
// based on the model accuracy.
padding_size[0] =
std::max<int>(0, (output_height - 1) * strides[0]
std::max<int>(0, (output_height - 1) * strides[0]
+ k_extent_height - input_shape[2]);
padding_size[1] =
std::max<int>(0, (output_width - 1) * strides[1]
std::max<int>(0, (output_width - 1) * strides[1]
+ k_extent_width - input_shape[3]);
output_shape[0] = input_shape[0];
......@@ -82,7 +78,7 @@ void CalPaddingSize(const index_t *input_shape, // NCHW
MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
"Invalid dilations, must >= 1");
MACE_CHECK((dilations[0] == 1 || strides[0] == 1) &&
(dilations[1] == 1 || strides[1] == 1),
(dilations[1] == 1 || strides[1] == 1),
"If dilations > 1, strides should be 1");
MACE_CHECK_NOTNULL(padding_size);
......@@ -91,20 +87,16 @@ void CalPaddingSize(const index_t *input_shape, // NCHW
index_t k_extent_width = (filter_shape[3] - 1) * dilations[1] + 1;
switch (padding) {
case VALID:
output_height = (input_shape[2] - k_extent_height) / strides[0] + 1;
case VALID:output_height = (input_shape[2] - k_extent_height) / strides[0] + 1;
output_width = (input_shape[3] - k_extent_width) / strides[1] + 1;
break;
case SAME:
output_height = (input_shape[2] - 1) / strides[0] + 1;
case SAME:output_height = (input_shape[2] - 1) / strides[0] + 1;
output_width = (input_shape[3] - 1) / strides[1] + 1;
break;
case FULL:
output_height = (input_shape[2] + k_extent_height - 2) / strides[0] + 1;
case FULL:output_height = (input_shape[2] + k_extent_height - 2) / strides[0] + 1;
output_width = (input_shape[3] + k_extent_width - 2) / strides[1] + 1;
break;
default:
MACE_CHECK(false, "Unsupported padding type: ", padding);
default:MACE_CHECK(false, "Unsupported padding type: ", padding);
}
// Note: TensorFlow may pad one more on the right/bottom side
......@@ -112,10 +104,10 @@ void CalPaddingSize(const index_t *input_shape, // NCHW
// utilize the more centered features. We need to benchmark
// based on the model accuracy.
padding_size[0] =
std::max<int>(0, (output_height - 1) * strides[0]
std::max<int>(0, (output_height - 1) * strides[0]
+ k_extent_height - input_shape[2]);
padding_size[1] =
std::max<int>(0, (output_width - 1) * strides[1]
std::max<int>(0, (output_width - 1) * strides[1]
+ k_extent_width - input_shape[3]);
}
......@@ -123,6 +115,7 @@ void ConstructInputWithPadding(const Tensor *input_tensor,
const int *paddings,
Tensor *output_tensor,
bool padding_same_value) {
VLOG(1) << "input: " << input_tensor->NumElements();
Tensor::MappingGuard input_mapper(input_tensor);
const float *input = input_tensor->data<float>();
const index_t *input_shape = input_tensor->shape().data();
......
......@@ -49,11 +49,13 @@ __kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */
int out_chan_len = out_chan_end - out_chan_begin;
int pixel_len = out_pixel_end - out_pixel_begin;
for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) {
float *output_ptr = output_base + out_chan * pixel_num;
float bias_value = bias[out_chan];
for (int p = 0; p < pixel_len; ++p) {
output_ptr[p] = bias_value;
if (bias != NULL) {
for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) {
float *output_ptr = output_base + out_chan * pixel_num;
float bias_value = bias[out_chan];
for (int p = 0; p < pixel_len; ++p) {
output_ptr[p] = bias_value;
}
}
}
......
......@@ -39,7 +39,8 @@ void kernel conv_2d_3x3(global const float *input,
float *output_ptr = output_base + i * out_pixel;
const float *filter_base = filter + i * in_chan_num * 9;
if (pixels == 4) {
float4 res = (float4)bias[i];
float4 res = bias == NULL ? 0 : (float4)bias[i];
for (int in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) {
const float* input_ptr = input_base + in_chan_idx * in_pixel;
const float* filter_ptr = filter_base + in_chan_idx * 9;
......@@ -56,7 +57,7 @@ void kernel conv_2d_3x3(global const float *input,
vstore4(res, 0, output_ptr);
} else {
for (int p = 0; p < pixels; ++p) {
float res = bias[i];
float res = bias == NULL ? 0 : bias[i];
for (uint in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) {
const float* input_ptr = input_base + in_chan_idx * in_pixel + p * stride_w;
const float* filter_ptr = filter_base + in_chan_idx * 9;
......
......@@ -68,8 +68,12 @@ void Conv1x1V2(const Tensor *input,
*(static_cast<const cl::Buffer *>(input->buffer())));
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Buffer *>(filter->buffer())));
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Buffer *>(bias->buffer())));
if (bias == NULL) {
conv_2d_kernel.setArg(idx++, NULL);
} else {
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Buffer *>(bias->buffer())));
}
conv_2d_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(output->buffer())));
conv_2d_kernel.setArg(idx++, static_cast<int>(input_channels));
conv_2d_kernel.setArg(idx++, static_cast<int>(channels));
......
......@@ -27,7 +27,11 @@ static void InnerConv2dK3x3S12(const Tensor *input, const Tensor *filter,
uint32_t idx = 0;
conv_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(input->buffer())));
conv_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(filter->buffer())));
conv_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(bias->buffer())));
if (bias == nullptr) {
conv_kernel.setArg(idx++, NULL);
} else {
conv_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(bias->buffer())));
}
conv_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(output->buffer())));
conv_kernel.setArg(idx++, static_cast<int32_t>(input->dim(1)));
conv_kernel.setArg(idx++, static_cast<int32_t>(channels));
......
......@@ -36,12 +36,13 @@ void ResizeBilinearFunctor<DeviceType::OPENCL, float>::operator()(
uint32_t idx = 0;
rb_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(input->buffer())));
rb_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(output->buffer())));
rb_kernel.setArg(idx++, static_cast<float>(height_scale));
rb_kernel.setArg(idx++, static_cast<float>(width_scale));
rb_kernel.setArg(idx++, height_scale);
rb_kernel.setArg(idx++, width_scale);
rb_kernel.setArg(idx++, static_cast<int>(in_height));
rb_kernel.setArg(idx++, static_cast<int>(in_width));
auto command_queue = runtime->command_queue();
cl_int error = command_queue.enqueueNDRangeKernel(
rb_kernel, cl::NullRange,
cl::NDRange(static_cast<int>(batch * channels),
......
......@@ -154,7 +154,7 @@ class ResizeBilinearFunctor {
if (size_[0] < 0 || size_[1] < 0) {
MACE_CHECK(resize_dims != nullptr && resize_dims->dim_size() == 1);
Tensor::MappingGuard resize_dims_mapper(resize_dims);
auto dims_data = resize_dims->data<index_t>();
auto dims_data = resize_dims->data<int32_t>();
*out_height = dims_data[0];
*out_width = dims_data[1];
} else {
......
......@@ -9,7 +9,7 @@ using namespace mace;
class Conv2dOpTest : public OpsTestBase {};
template <DeviceType D>
template<DeviceType D>
void TestSimple3x3VALID() {
OpsTestNet net;
OpDefBuilder("Conv2D", "Conv2dTest")
......@@ -44,7 +44,7 @@ void TestSimple3x3VALID() {
}
template <DeviceType D>
template<DeviceType D>
void TestSimple3x3SAME() {
OpsTestNet net;
OpDefBuilder("Conv2D", "Conv2dTest")
......@@ -93,7 +93,51 @@ TEST_F(Conv2dOpTest, OPENCLSimple) {
TestSimple3x3SAME<DeviceType::OPENCL>();
}
template <DeviceType D>
template<DeviceType D>
void TestSimple3x3WithoutBias() {
OpsTestNet net;
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Output("Output")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// Add args
// Add input data
net.AddInputFromArray<D, float>(
"Input", {1, 2, 3, 3},
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<D, float>(
"Filter", {1, 2, 3, 3},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
// Run
net.RunOp(D);
// Check
auto expected = CreateTensor<float>({1, 1, 1, 1}, {18.0f});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
TEST_F(Conv2dOpTest, CPUWithoutBias) {
TestSimple3x3WithoutBias<DeviceType::CPU>();
}
TEST_F(Conv2dOpTest, NEONWithoutBias) {
TestSimple3x3WithoutBias<DeviceType::NEON>();
}
TEST_F(Conv2dOpTest, OPENCLWithoutBias) {
TestSimple3x3WithoutBias<DeviceType::OPENCL>();
}
template<DeviceType D>
static void TestCombined3x3() {
// Construct graph
OpsTestNet net;
......@@ -143,7 +187,7 @@ TEST_F(Conv2dOpTest, OPENCLCombined) {
TestCombined3x3<DeviceType::OPENCL>();
}
template <DeviceType D>
template<DeviceType D>
void TestConv1x1() {
// Construct graph
OpsTestNet net;
......@@ -178,9 +222,9 @@ void TestConv1x1() {
// Check
auto expected = CreateTensor<float>(
{1, 2, 3, 10},
{5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f,
5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f,
5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f,
{5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f,
5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f,
5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f,
10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f,
10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f,
10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f});
......@@ -196,7 +240,7 @@ TEST_F(Conv2dOpTest, OPENCLConv1x1) {
TestConv1x1<DeviceType::OPENCL>();
}
template <DeviceType D>
template<DeviceType D>
static void TestAlignedConvNxNS12() {
testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
......@@ -254,7 +298,7 @@ TEST_F(Conv2dOpTest, OPENCLAlignedConvNxNS12) {
TestAlignedConvNxNS12<DeviceType::OPENCL>();
}
template <DeviceType D>
template<DeviceType D>
static void TestUnalignedConvNxNS12() {
testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
......
......@@ -24,7 +24,7 @@ TEST_F(ResizeBilinearTest, CPUResizeBilinearWOAlignCorners) {
vector<float> input(24);
std::iota(begin(input), end(input), 0);
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 3, 2, 4}, input);
net.AddInputFromArray<DeviceType::CPU, index_t>("OutSize", {2}, {1, 2});
net.AddInputFromArray<DeviceType::CPU, int>("OutSize", {2}, {1, 2});
// Run
net.RunOp();
......@@ -50,7 +50,7 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) {
vector<float> input(24);
std::iota(begin(input), end(input), 0);
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 3, 2, 4}, input);
net.AddInputFromArray<DeviceType::CPU, index_t>("OutSize", {2}, {1, 2});
net.AddInputFromArray<DeviceType::CPU, int>("OutSize", {2}, {1, 2});
// Run
net.RunOp();
......@@ -86,7 +86,7 @@ void TestRandomResizeBilinear() {
// Add input data
net.AddRandomInput<D, float>("Input",
{batch, channels, in_height, in_width});
net.AddInputFromArray<D, index_t>("OutSize", {2}, {height, width});
net.AddInputFromArray<D, int>("OutSize", {2}, {height, width});
// Run
net.RunOp(D);
......
......@@ -2,6 +2,7 @@ from mace.proto import mace_pb2
import tensorflow as tf
import numpy as np
# TODO: support NCHW format; currently only NHWC is supported.
padding_mode = {
'VALID': 0,
'SAME': 1,
......@@ -22,7 +23,7 @@ def convert_tensor(op, tensor):
op.name.endswith('weights') or
op.name.endswith('kernel')) \
and op.outputs[0].consumers()[0].type.find('Conv') != -1:
if op.outputs[0].consumers()[0].get_attr('data_format') == 'NCHW':
if op.outputs[0].consumers()[0].get_attr('data_format') == 'NHWC':
tf_tensor = np.transpose(tf_tensor, axes=(3, 2, 0, 1))
shape = [shape[3], shape[2], shape[0], shape[1]]
# print (tensor.name, shape)
......@@ -70,7 +71,7 @@ def convert_ops(unresolved_ops, net_def):
padding_arg.i = padding_mode[first_op.get_attr('padding')]
strides_arg = op_def.arg.add()
strides_arg.name = 'strides'
strides_arg.ints.extend(first_op.get_attr('strides')[2:])
strides_arg.ints.extend(first_op.get_attr('strides')[1:3])
data_format_arg = op_def.arg.add()
data_format_arg.name = 'data_format'
data_format_arg.s = 'NCHW'
......@@ -129,10 +130,10 @@ def convert_ops(unresolved_ops, net_def):
padding_arg.i = padding_mode[first_op.get_attr('padding')]
strides_arg = op_def.arg.add()
strides_arg.name = 'strides'
strides_arg.ints.extend(first_op.get_attr('strides')[2:])
strides_arg.ints.extend(first_op.get_attr('strides')[1:3])
kernels_arg = op_def.arg.add()
kernels_arg.name = 'kernels'
kernels_arg.ints.extend(first_op.get_attr('ksize')[2:])
kernels_arg.ints.extend(first_op.get_attr('ksize')[1:3])
data_format_arg = op_def.arg.add()
data_format_arg.name = 'data_format'
data_format_arg.s = 'NCHW'
......