Commit 736ceb5f authored by 李寅

Refactor conv related op

Parent f16cf77e
......@@ -8,6 +8,7 @@
#include <limits>
#include <sstream>
#include <string>
#include <vector>
#undef ERROR
......@@ -41,7 +42,16 @@ template <typename... Args>
string MakeString(const Args&... args) {
std::stringstream ss;
MakeStringInternal(ss, args...);
return string(ss.str());
return ss.str();
}
template <typename T>
string MakeString(const std::vector<T> &args) {
std::stringstream ss;
for (const T& arg: args) {
ss << arg << ", ";
}
return ss.str();
}
// Specializations for already-a-string types.
......
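Note: the new vector overload of MakeString joins the elements for logging. A minimal usage sketch (the values are invented; index_t is the project's index alias):

    std::vector<index_t> shape = {1, 3, 224, 224};
    std::string s = internal::MakeString(shape);
    // s == "1, 3, 224, 224, " -- the overload leaves a trailing ", ",
    // which is harmless in debug output.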
......@@ -35,6 +35,8 @@ bool SimpleNet::Run() {
LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def());
return false;
}
VLOG(1) << "Op " << op->debug_def().name()
<< " has shape: " << internal::MakeString(op->Output(0)->shape());
}
return true;
}
......
......@@ -137,7 +137,7 @@ class Tensor {
alloc_->CopyBytes(raw_mutable_data(), src, size);
}
inline void DebugPrint() {
inline void DebugPrint() const {
std::stringstream os;
for (int i : shape_) {
os << i << ", ";
......
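Note: const-qualifying DebugPrint() lets it be called through const references and pointers. A sketch (Inspect is a made-up helper; Tensor is this repo's class):

    void Inspect(const Tensor &tensor) {
      tensor.DebugPrint();  // would not compile without the const qualifier
    }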
......@@ -53,6 +53,10 @@ Tensor* Workspace::GetTensor(const string& name) {
void Workspace::LoadModelTensor(const NetDef& net_def, DeviceType type) {
Serializer serializer;
for (auto& tensor_proto : net_def.tensors()) {
VLOG(1) << "Load tensor: " << tensor_proto.name()
<< " has shape: " << internal::MakeString(vector<index_t>(
tensor_proto.dims().begin(), tensor_proto.dims().end()));
tensor_map_[tensor_proto.name()] =
serializer.Deserialize(tensor_proto, type);
}
......
......@@ -12,19 +12,8 @@ namespace mace {
namespace kernels {
template<DeviceType D, typename T>
class Conv2dFunctor {
public:
Conv2dFunctor(const index_t *input_shape,
const index_t *filter_shape,
const int *strides,
const Padding padding,
const int *dilations) :
strides_(strides),
paddings_(2, 0),
dilations_(dilations) {
CalPaddingSize(input_shape, filter_shape, dilations_, strides_, padding, paddings_.data());
}
struct Conv2dFunctor {
Conv2dFunctor() {}
Conv2dFunctor(const int *strides,
const std::vector<int> &paddings,
const int *dilations) :
......@@ -112,7 +101,6 @@ class Conv2dFunctor {
}
}
private:
const int *strides_; // [stride_h, stride_w]
std::vector<int> paddings_; // [padding_h, padding_w]
const int *dilations_; // [dilation_h, dilation_w]
......
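Note: the class-to-struct change fits the new usage: strides_, paddings_ and dilations_ are now assigned directly by the owning op, so nothing stays private. In C++ the two keywords differ only in default access:

    struct S { int x; };  // x is public by default
    class  C { int x; };  // x is private by default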
......@@ -13,18 +13,8 @@ namespace mace {
namespace kernels {
template<DeviceType D, typename T>
class DepthwiseConv2dFunctor {
public:
DepthwiseConv2dFunctor(const index_t *input_shape,
const index_t *filter_shape,
const int *strides,
const Padding padding,
const int *dilations) :
strides_(strides),
paddings_(2, 0),
dilations_(dilations) {
CalPaddingSize(input_shape, filter_shape, dilations_, strides_, padding, paddings_.data());
}
struct DepthwiseConv2dFunctor {
DepthwiseConv2dFunctor() {}
DepthwiseConv2dFunctor(const int *strides,
const std::vector<int> &paddings,
const int *dilations) :
......@@ -39,7 +29,6 @@ class DepthwiseConv2dFunctor {
const T *bias, // c_out
T *output, // NCHW
const index_t *output_shape) {
MACE_CHECK_NOTNULL(output);
index_t batch = output_shape[0];
......@@ -111,7 +100,7 @@ class DepthwiseConv2dFunctor {
}
}
}
private:
const int *strides_; // [stride_h, stride_w]
std::vector<int> paddings_; // [padding_h, padding_w]
const int *dilations_; // [dilation_h, dilation_w]
......
......@@ -11,10 +11,7 @@ namespace mace {
namespace kernels {
template <DeviceType D, typename T>
class GlobalAvgPoolingFunctor {
public:
GlobalAvgPoolingFunctor() {}
struct GlobalAvgPoolingFunctor {
void operator()(const T *input, const index_t *input_shape, T *output) {
index_t batch = input_shape[0];
index_t channels = input_shape[1];
......
......@@ -18,8 +18,7 @@ enum PoolingType {
namespace kernels {
template <DeviceType D, typename T>
class PoolingFunctor {
public:
struct PoolingFunctor {
PoolingFunctor(const PoolingType pooling_type,
const int *kernels,
const int *strides,
......@@ -114,7 +113,6 @@ class PoolingFunctor {
}
}
private:
const PoolingType pooling_type_;
const int *kernels_;
const int *strides_;
......
......@@ -6,10 +6,10 @@
namespace mace {
REGISTER_CPU_OPERATOR(Conv2d, Conv2dOp<DeviceType::CPU, float>);
REGISTER_CPU_OPERATOR(Conv2D, Conv2dOp<DeviceType::CPU, float>);
#if __ARM_NEON
REGISTER_NEON_OPERATOR(Conv2d, Conv2dOp<DeviceType::NEON, float>);
REGISTER_NEON_OPERATOR(Conv2D, Conv2dOp<DeviceType::NEON, float>);
#endif // __ARM_NEON
} // namespace mace
......@@ -17,12 +17,10 @@ template<DeviceType D, typename T>
class Conv2dOp : public ConvPool2dOpBase<D, T> {
public:
Conv2dOp(const OperatorDef &op_def, Workspace *ws)
: ConvPool2dOpBase<D, T>(op_def, ws),
functor_(this->Input(INPUT)->shape().data(),
this->Input(FILTER)->shape().data(),
this->strides_.data(),
this->padding_,
this->dilations_.data()) {}
: ConvPool2dOpBase<D, T>(op_def, ws) {
functor_.strides_ = this->strides_.data();
functor_.dilations_ = this->dilations_.data();
}
bool Run() override {
const Tensor *input = this->Input(INPUT);
......@@ -37,8 +35,13 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
this->CalOutputSize(input->shape().data(), filter->shape().data(), output_shape.data());
kernels::CalcPaddingAndOutputSize(input->shape().data(),
filter->shape().data(),
this->dilations_.data(),
this->strides_.data(), this->padding_,
output_shape.data(), paddings.data());
output->Resize(output_shape);
functor_.paddings_ = paddings;
functor_(input->data<T>(), input->shape().data(), filter->data<T>(),
filter->shape().data(), bias_data, output->mutable_data<T>(),
......
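Note: putting the hunks together, the functor is now default-constructed and configured field by field, and padding is computed per Run() rather than in the constructor. A sketch of the resulting call sequence (argument order inferred from the call site above, so treat the exact signature as an assumption):

    kernels::Conv2dFunctor<DeviceType::CPU, float> functor;
    functor.strides_ = strides.data();      // [stride_h, stride_w]
    functor.dilations_ = dilations.data();  // [dilation_h, dilation_w]

    std::vector<index_t> output_shape(4);
    std::vector<int> paddings(2);
    kernels::CalcPaddingAndOutputSize(input_shape, filter_shape,
                                      dilations.data(), strides.data(),
                                      padding, output_shape.data(),
                                      paddings.data());
    functor.paddings_ = paddings;           // [padding_h, padding_w]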
......@@ -25,7 +25,7 @@ static void Conv2d(int iters,
mace::testing::StopTiming();
OpsTestNet net;
OpDefBuilder("Conv2d", "Conv2dTest")
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
......
......@@ -13,7 +13,7 @@ class Conv2dOpTest : public OpsTestBase {};
TEST_F(Conv2dOpTest, Simple_VALID) {
// Construct graph
auto& net = test_net();
OpDefBuilder("Conv2d", "Conv2dTest")
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
......@@ -47,7 +47,7 @@ TEST_F(Conv2dOpTest, Simple_VALID) {
TEST_F(Conv2dOpTest, Simple_SAME) {
// Construct graph
auto& net = test_net();
OpDefBuilder("Conv2d", "Conv2dTest")
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
......@@ -83,7 +83,7 @@ TEST_F(Conv2dOpTest, Simple_SAME) {
TEST_F(Conv2dOpTest, Combined) {
// Construct graph
auto& net = test_net();
OpDefBuilder("Conv2d", "Conv2dTest")
OpDefBuilder("Conv2D", "Conv2DTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
......@@ -121,7 +121,7 @@ TEST_F(Conv2dOpTest, Combined) {
TEST_F(Conv2dOpTest, Conv1x1) {
// Construct graph
auto& net = test_net();
OpDefBuilder("Conv2d", "Conv2dTest")
OpDefBuilder("Conv2D", "Conv2DTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
......@@ -179,7 +179,7 @@ TEST_F(Conv2dOpTest, ConvNxNS12) {
index_t output_channels = 1 + rand() % 10;
// Construct graph
auto& net = test_net();
OpDefBuilder("Conv2d", "Conv2dTest")
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
......
......@@ -19,48 +19,7 @@ class ConvPool2dOpBase : public Operator<D, T> {
padding_(static_cast<Padding>(OperatorBase::GetSingleArgument<int>(
"padding", static_cast<int>(SAME)))),
dilations_(OperatorBase::GetRepeatedArgument<int>("dilations", {1, 1})) {}
void CalOutputSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW
index_t *output_shape) {
MACE_CHECK(dilations_[0] > 0 && dilations_[1] > 0,
"Invalid dilations, must >= 1");
MACE_CHECK((dilations_[0] == 1 || strides_[0] == 1) &&
(dilations_[1] == 1 || strides_[1] == 1),
"If dilations > 1, strides should be 1");
MACE_CHECK_NOTNULL(output_shape);
/*
* Convolution/pooling arithmetic:
* o = (i + 2 * p - k - (k - 1) * (d - 1)) / s + 1
* For details, see https://arxiv.org/pdf/1603.07285.pdf or
* http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html
*/
index_t output_height = 0, output_width = 0;
switch (padding_) {
case VALID:
output_height = (input_shape[2] - (filter_shape[2] - 1) * dilations_[0] - 1) / strides_[0] + 1;
output_width = (input_shape[3] - (filter_shape[3] - 1) * dilations_[1] - 1) / strides_[1] + 1;
break;
case SAME:
output_height = (input_shape[2] - 1) / strides_[0] + 1;
output_width = (input_shape[3] - 1) / strides_[1] + 1;
break;
case FULL:
output_height = (input_shape[2] + (filter_shape[2] - 1) * dilations_[0] - 1) / strides_[0] + 1;
output_width = (input_shape[3] + (filter_shape[3] - 1) * dilations_[1] - 1) / strides_[1] + 1;
break;
default:
MACE_CHECK(false, "Unsupported padding type: ", padding_);
}
output_shape[0] = input_shape[0];
output_shape[1] = filter_shape[0];
output_shape[2] = output_height;
output_shape[3] = output_width;
}
protected:
std::vector<int> strides_;
Padding padding_;
......
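Note: a worked instance of the arithmetic removed above, o = (i + 2 * p - k - (k - 1) * (d - 1)) / s + 1: for input i = 7, kernel k = 3, dilation d = 2, stride s = 1 and VALID padding (p = 0), the dilated kernel spans k + (k - 1) * (d - 1) = 5, giving o = (7 - 5) / 1 + 1 = 3, which matches the VALID branch (7 - (3 - 1) * 2 - 1) / 1 + 1 = 3.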
......@@ -18,10 +18,10 @@ template<DeviceType D, typename T>
class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
public:
DepthwiseConv2dOp(const OperatorDef &op_def, Workspace *ws)
: ConvPool2dOpBase<D, T>(op_def, ws),
functor_(this->Input(INPUT)->shape().data(),
this->Input(FILTER)->shape().data(),
this->strides_.data(), this->padding_, this->dilations_.data()) {};
: ConvPool2dOpBase<D, T>(op_def, ws) {
functor_.strides_ = this->strides_.data();
functor_.dilations_ = this->dilations_.data();
}
bool Run() override {
const Tensor *input = this->Input(INPUT);
......@@ -38,8 +38,14 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
filter_shape[0] *= filter_shape[1];
filter_shape[1] = 1;
std::vector<index_t> output_shape(4);
this->CalOutputSize(input->shape().data(), filter_shape.data(), output_shape.data());
std::vector<int> paddings(2);
kernels::CalcPaddingAndOutputSize(input->shape().data(),
filter_shape.data(),
this->dilations_.data(),
this->strides_.data(), this->padding_,
output_shape.data(), paddings.data());
output->Resize(output_shape);
functor_.paddings_ = paddings;
functor_(input->data<T>(), input->shape().data(), filter->data<T>(),
filter_shape.data(), bias_data, output->mutable_data<T>(),
......
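Note: the filter_shape rewrite above folds the channel multiplier into the output-channel dimension, so the depthwise filter can be passed through the ordinary conv shape machinery. A worked sketch, assuming the converter's (multiplier, in_channels, kh, kw) layout:

    std::vector<index_t> filter_shape = {2, 3, 3, 3};  // M = 2, C = 3, 3x3 kernel
    filter_shape[0] *= filter_shape[1];  // 2 * 3 = 6 output channels
    filter_shape[1] = 1;                 // each output channel reads one input channel
    // -> (6, 1, 3, 3)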
......@@ -10,9 +10,10 @@ using namespace mace;
class DepthwiseConv2dOpTest : public OpsTestBase {};
TEST_F(DepthwiseConv2dOpTest, Simple_VALID) {
testing::internal::LogToStderr();
// Construct graph
auto& net = test_net();
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2dTest")
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
......@@ -35,7 +36,6 @@ TEST_F(DepthwiseConv2dOpTest, Simple_VALID) {
3.0f, 7.0f, 11.0f, 15.0f,
4.0f, 8.0f, 12.0f, 16.0f});
net.AddInputFromArray<float>("Bias", {4}, {.1f, .2f, .3f, .4f});
// Run
net.RunOp();
......@@ -61,7 +61,7 @@ TEST_F(DepthwiseConv2dOpTest, ConvNxNS12) {
index_t multiplier = 3 + rand() % 10;
// Construct graph
auto& net = test_net();
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2dTest")
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
......
......@@ -25,7 +25,7 @@ static void DepthwiseConv2d(int iters,
mace::testing::StopTiming();
OpsTestNet net;
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2dTest")
OpDefBuilder("DepthwiseConv2D", "DepthwiseConv2DTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
......
......@@ -28,6 +28,7 @@ class PoolingOp : public ConvPool2dOpBase<D, T> {
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
std::vector<index_t> filter_shape(4);
// TODO(chenghui): is it kind of a hack?
filter_shape[0] = input->shape()[1];
filter_shape[1] = input->shape()[0];
filter_shape[2] = kernels_[0];
......
......@@ -23,6 +23,7 @@ def main(unused_args):
with gfile.GFile(FLAGS.output, "wb") as f:
f.write(output_graph_def.SerializeToString())
with gfile.GFile(FLAGS.output + '_txt', "wb") as f:
output_graph_def.ClearField('tensors')
f.write(str(output_graph_def))
......
from mace.proto import mace_pb2
import tensorflow as tf
import numpy as np
padding_mode = {
'VALID': 0,
......@@ -24,11 +25,20 @@ def convert_ops(unresolved_ops, net_def):
tf_tensor = first_op.outputs[0].eval()
tensor = net_def.tensors.add()
tensor.name = first_op.outputs[0].name
tensor.dims.extend(tf_tensor.shape)
# TODO: support types other than float
tensor.data_type = mace_pb2.DT_FLOAT
shape = list(tf_tensor.shape)
if (first_op.name.find('pointwise_kernel') != -1 or
first_op.name.find('depthwise_kernel') != -1 or
first_op.name.endswith('weights') or
first_op.name.endswith('kernel')) \
and first_op.outputs[0].consumers()[0].type.find('Conv') != -1:
tf_tensor = np.transpose(tf_tensor, axes=(3, 2, 0, 1))
shape = [shape[3], shape[2], shape[0], shape[1]]
# print (tensor.name, shape)
tensor.dims.extend(shape)
tensor.float_data.extend(tf_tensor.astype(float).flat)
# net_def.tensors.extend([tensor])
elif first_op.type == 'Conv2D' or first_op.type == 'DepthwiseConv2dNative':
op_def = net_def.op.add()
op_def.name = first_op.name
......@@ -43,10 +53,12 @@ def convert_ops(unresolved_ops, net_def):
padding_arg.i = padding_mode[first_op.get_attr('padding')]
strides_arg = op_def.arg.add()
strides_arg.name = 'strides'
strides_arg.ints.extend(first_op.get_attr('strides'))
strides_arg.ints.extend(first_op.get_attr('strides')[2:])
data_format_arg = op_def.arg.add()
data_format_arg.name = 'data_format'
data_format_arg.s = first_op.get_attr('data_format')
if first_op.get_attr('data_format') != 'NCHW':
raise Exception('only support NCHW now')
if ops_count >= 2 and unresolved_ops[1].type == 'BiasAdd':
bias_add_op = unresolved_ops[1]
......@@ -93,7 +105,7 @@ def convert_ops(unresolved_ops, net_def):
op_def.type = first_op.type
op_def.input.extend([input.name for input in first_op.inputs])
op_def.output.extend([output.name for output in first_op.outputs])
elif first_op.type == 'AvgPool':
elif first_op.type == 'AvgPool' or first_op.type == 'MaxPool':
op_def = net_def.op.add()
op_def.name = first_op.name
op_def.type = 'Pooling'
......@@ -107,12 +119,15 @@ def convert_ops(unresolved_ops, net_def):
padding_arg.i = padding_mode[first_op.get_attr('padding')]
strides_arg = op_def.arg.add()
strides_arg.name = 'strides'
strides_arg.ints.extend(first_op.get_attr('strides')[1:-1])
strides_arg.name = 'kernels'
strides_arg.ints.extend(first_op.get_attr('ksize')[1:-1])
strides_arg.ints.extend(first_op.get_attr('strides')[2:])
kernels_arg = op_def.arg.add()
kernels_arg.name = 'kernels'
kernels_arg.ints.extend(first_op.get_attr('ksize')[2:])
data_format_arg = op_def.arg.add()
data_format_arg.name = 'data_format'
data_format_arg.s = first_op.get_attr('data_format')
if first_op.get_attr('data_format') != 'NCHW':
raise Exception('only support NCHW now')
else:
raise Exception('Unknown Op: ' + first_op.name)
pass
......
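Note on the [2:] slices above: TensorFlow stores the strides and ksize attributes in data_format order, so with NCHW the spatial components sit at indices 2 and 3 (hence [2:]), whereas the previous slices assumed NHWC; the new data_format check makes that assumption explicit.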