Commit 736ceb5f authored by 李寅

Refactor conv related op

Parent f16cf77e
@@ -8,6 +8,7 @@
 #include <limits>
 #include <sstream>
 #include <string>
+#include <vector>
 #undef ERROR
@@ -41,7 +42,16 @@ template <typename... Args>
 string MakeString(const Args&... args) {
   std::stringstream ss;
   MakeStringInternal(ss, args...);
-  return string(ss.str());
+  return ss.str();
+}
+
+template <typename T>
+string MakeString(const std::vector<T> &args) {
+  std::stringstream ss;
+  for (const T& arg: args) {
+    ss << arg << ", ";
+  }
+  return ss.str();
 }
 // Specializations for already-a-string types.
......
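Note: the new vector overload of MakeString joins elements with ", " and leaves a trailing separator, unlike the variadic overload, which streams its arguments back to back. A standalone sketch (not the MACE source) of the behavior the shape-logging calls below rely on:

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    // Mirrors the overload added above: join with ", ", trailing
    // separator included.
    template <typename T>
    std::string MakeString(const std::vector<T>& args) {
      std::stringstream ss;
      for (const T& arg : args) {
        ss << arg << ", ";
      }
      return ss.str();
    }

    int main() {
      std::vector<int> shape = {1, 3, 224, 224};
      std::cout << MakeString(shape) << std::endl;  // "1, 3, 224, 224, "
      return 0;
    }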
@@ -35,6 +35,8 @@ bool SimpleNet::Run() {
       LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def());
       return false;
     }
+    VLOG(1) << "Op " << op->debug_def().name()
+            << " has shape: " << internal::MakeString(op->Output(0)->shape());
   }
   return true;
 }
......
@@ -137,7 +137,7 @@ class Tensor {
     alloc_->CopyBytes(raw_mutable_data(), src, size);
   }
-  inline void DebugPrint() {
+  inline void DebugPrint() const {
     std::stringstream os;
     for (int i : shape_) {
       os << i << ", ";
......
@@ -53,6 +53,10 @@ Tensor* Workspace::GetTensor(const string& name) {
 void Workspace::LoadModelTensor(const NetDef& net_def, DeviceType type) {
   Serializer serializer;
   for (auto& tensor_proto : net_def.tensors()) {
+    VLOG(1) << "Load tensor: " << tensor_proto.name()
+            << " has shape: " << internal::MakeString(vector<index_t>(
+                tensor_proto.dims().begin(), tensor_proto.dims().end()));
     tensor_map_[tensor_proto.name()] =
         serializer.Deserialize(tensor_proto, type);
   }
......
@@ -12,19 +12,8 @@ namespace mace {
 namespace kernels {
 template<DeviceType D, typename T>
-class Conv2dFunctor {
- public:
-  Conv2dFunctor(const index_t *input_shape,
-                const index_t *filter_shape,
-                const int *strides,
-                const Padding padding,
-                const int *dilations) :
-      strides_(strides),
-      paddings_(2, 0),
-      dilations_(dilations) {
-    CalPaddingSize(input_shape, filter_shape, dilations_, strides_, padding, paddings_.data());
-  }
+struct Conv2dFunctor {
+  Conv2dFunctor() {}
   Conv2dFunctor(const int *strides,
                 const std::vector<int> &paddings,
                 const int *dilations) :
@@ -112,7 +101,6 @@ class Conv2dFunctor {
     }
   }
- private:
   const int *strides_;         // [stride_h, stride_w]
   std::vector<int> paddings_;  // [padding_h, padding_w]
   const int *dilations_;       // [dilation_h, dilation_w]
......
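Note: the functor no longer computes padding in its constructor. Padding depends on the actual input shape, which is only known per run, so the owning op now performs that computation (see the Conv2dOp changes below) and assigns paddings_ before each invocation. A minimal sketch of the resulting pattern, with placeholder types standing in for the MACE API:

    #include <vector>

    // Default-constructible functor whose public members are wired up
    // by the owning op; placeholder types, not the real MACE functor.
    struct Conv2dFunctorSketch {
      const int *strides_ = nullptr;    // [stride_h, stride_w]
      std::vector<int> paddings_;       // [padding_h, padding_w]
      const int *dilations_ = nullptr;  // [dilation_h, dilation_w]

      void operator()() {
        // the real operator() runs the convolution using the members above
      }
    };

    int main() {
      int strides[2] = {1, 1};
      int dilations[2] = {1, 1};

      Conv2dFunctorSketch functor;  // op constructor: no shapes needed yet
      functor.strides_ = strides;
      functor.dilations_ = dilations;

      functor.paddings_ = {1, 1};   // op Run(): depends on the input shape
      functor();
      return 0;
    }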
@@ -13,18 +13,8 @@ namespace mace {
 namespace kernels {
 template<DeviceType D, typename T>
-class DepthwiseConv2dFunctor {
- public:
-  DepthwiseConv2dFunctor(const index_t *input_shape,
-                         const index_t *filter_shape,
-                         const int *strides,
-                         const Padding padding,
-                         const int *dilations) :
-      strides_(strides),
-      paddings_(2, 0),
-      dilations_(dilations) {
-    CalPaddingSize(input_shape, filter_shape, dilations_, strides_, padding, paddings_.data());
-  }
+struct DepthwiseConv2dFunctor {
+  DepthwiseConv2dFunctor() {}
   DepthwiseConv2dFunctor(const int *strides,
                          const std::vector<int> &paddings,
                          const int *dilations) :
@@ -39,7 +29,6 @@ class DepthwiseConv2dFunctor {
                   const T *bias,    // c_out
                   T *output,        // NCHW
                   const index_t *output_shape) {
-
     MACE_CHECK_NOTNULL(output);
     index_t batch = output_shape[0];
@@ -111,7 +100,7 @@ class DepthwiseConv2dFunctor {
       }
     }
   }
- private:
   const int *strides_;         // [stride_h, stride_w]
   std::vector<int> paddings_;  // [padding_h, padding_w]
   const int *dilations_;       // [dilation_h, dilation_w]
......
@@ -11,10 +11,7 @@ namespace mace {
 namespace kernels {
 template <DeviceType D, typename T>
-class GlobalAvgPoolingFunctor {
- public:
-  GlobalAvgPoolingFunctor() {}
+struct GlobalAvgPoolingFunctor {
   void operator()(const T *input, const index_t *input_shape, T *output) {
     index_t batch = input_shape[0];
     index_t channels = input_shape[1];
......
@@ -18,8 +18,7 @@ enum PoolingType {
 namespace kernels {
 template <DeviceType D, typename T>
-class PoolingFunctor {
- public:
+struct PoolingFunctor {
   PoolingFunctor(const PoolingType pooling_type,
                  const int *kernels,
                  const int *strides,
@@ -114,7 +113,6 @@
     }
   }
- private:
   const PoolingType pooling_type_;
   const int *kernels_;
   const int *strides_;
......
@@ -6,10 +6,10 @@
 namespace mace {
-REGISTER_CPU_OPERATOR(Conv2d, Conv2dOp<DeviceType::CPU, float>);
+REGISTER_CPU_OPERATOR(Conv2D, Conv2dOp<DeviceType::CPU, float>);
 #if __ARM_NEON
-REGISTER_NEON_OPERATOR(Conv2d, Conv2dOp<DeviceType::NEON, float>);
+REGISTER_NEON_OPERATOR(Conv2D, Conv2dOp<DeviceType::NEON, float>);
 #endif  // __ARM_NEON
 }  // namespace mace
@@ -17,12 +17,10 @@ template<DeviceType D, typename T>
 class Conv2dOp : public ConvPool2dOpBase<D, T> {
  public:
   Conv2dOp(const OperatorDef &op_def, Workspace *ws)
-      : ConvPool2dOpBase<D, T>(op_def, ws),
-        functor_(this->Input(INPUT)->shape().data(),
-                 this->Input(FILTER)->shape().data(),
-                 this->strides_.data(),
-                 this->padding_,
-                 this->dilations_.data()) {}
+      : ConvPool2dOpBase<D, T>(op_def, ws) {
+    functor_.strides_ = this->strides_.data();
+    functor_.dilations_ = this->dilations_.data();
+  }
   bool Run() override {
     const Tensor *input = this->Input(INPUT);
@@ -37,8 +35,13 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
     std::vector<index_t> output_shape(4);
     std::vector<int> paddings(2);
-    this->CalOutputSize(input->shape().data(), filter->shape().data(), output_shape.data());
+    kernels::CalcPaddingAndOutputSize(input->shape().data(),
+                                      filter->shape().data(),
+                                      this->dilations_.data(),
+                                      this->strides_.data(), this->padding_,
+                                      output_shape.data(), paddings.data());
     output->Resize(output_shape);
+    functor_.paddings_ = paddings;
     functor_(input->data<T>(), input->shape().data(), filter->data<T>(),
              filter->shape().data(), bias_data, output->mutable_data<T>(),
......
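Note: kernels::CalcPaddingAndOutputSize is introduced by this commit, but its declaration is not among the hunks shown here. Inferred from the call sites above, a hypothetical prototype (parameter names and order assumed, not confirmed by this diff):

    // Hypothetical prototype reconstructed from the call sites; the real
    // declaration lives in the MACE kernels headers, outside this diff.
    // index_t and Padding are restated only to make the sketch compile.
    #include <cstdint>

    typedef int64_t index_t;             // stand-in for mace::index_t
    enum Padding { VALID, SAME, FULL };  // as used by the ops

    namespace kernels {
    void CalcPaddingAndOutputSize(const index_t *input_shape,   // NCHW
                                  const index_t *filter_shape,  // OIHW
                                  const int *dilations,         // [d_h, d_w]
                                  const int *strides,           // [s_h, s_w]
                                  Padding padding,
                                  index_t *output_shape,        // NCHW, written
                                  int *padding_size);           // [p_h, p_w], written
    }  // namespace kernels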
@@ -25,7 +25,7 @@ static void Conv2d(int iters,
   mace::testing::StopTiming();
   OpsTestNet net;
-  OpDefBuilder("Conv2d", "Conv2dTest")
+  OpDefBuilder("Conv2D", "Conv2dTest")
       .Input("Input")
       .Input("Filter")
       .Input("Bias")
......
@@ -13,7 +13,7 @@ class Conv2dOpTest : public OpsTestBase {};
 TEST_F(Conv2dOpTest, Simple_VALID) {
   // Construct graph
   auto& net = test_net();
-  OpDefBuilder("Conv2d", "Conv2dTest")
+  OpDefBuilder("Conv2D", "Conv2dTest")
       .Input("Input")
       .Input("Filter")
       .Input("Bias")
@@ -47,7 +47,7 @@ TEST_F(Conv2dOpTest, Simple_VALID) {
 TEST_F(Conv2dOpTest, Simple_SAME) {
   // Construct graph
   auto& net = test_net();
-  OpDefBuilder("Conv2d", "Conv2dTest")
+  OpDefBuilder("Conv2D", "Conv2dTest")
       .Input("Input")
       .Input("Filter")
       .Input("Bias")
@@ -83,7 +83,7 @@ TEST_F(Conv2dOpTest, Simple_SAME) {
 TEST_F(Conv2dOpTest, Combined) {
   // Construct graph
   auto& net = test_net();
-  OpDefBuilder("Conv2d", "Conv2dTest")
+  OpDefBuilder("Conv2D", "Conv2DTest")
       .Input("Input")
       .Input("Filter")
       .Input("Bias")
@@ -121,7 +121,7 @@ TEST_F(Conv2dOpTest, Combined) {
 TEST_F(Conv2dOpTest, Conv1x1) {
   // Construct graph
   auto& net = test_net();
-  OpDefBuilder("Conv2d", "Conv2dTest")
+  OpDefBuilder("Conv2D", "Conv2DTest")
       .Input("Input")
       .Input("Filter")
       .Input("Bias")
@@ -179,7 +179,7 @@ TEST_F(Conv2dOpTest, ConvNxNS12) {
   index_t output_channels = 1 + rand() % 10;
   // Construct graph
   auto& net = test_net();
-  OpDefBuilder("Conv2d", "Conv2dTest")
+  OpDefBuilder("Conv2D", "Conv2dTest")
       .Input("Input")
       .Input("Filter")
       .Input("Bias")
......
@@ -19,48 +19,7 @@ class ConvPool2dOpBase : public Operator<D, T> {
         padding_(static_cast<Padding>(OperatorBase::GetSingleArgument<int>(
             "padding", static_cast<int>(SAME)))),
         dilations_(OperatorBase::GetRepeatedArgument<int>("dilations", {1, 1})) {}
-  void CalOutputSize(const index_t *input_shape,   // NCHW
-                     const index_t *filter_shape,  // OIHW
-                     index_t *output_shape) {
-    MACE_CHECK(dilations_[0] > 0 && dilations_[1] > 0,
-               "Invalid dilations, must >= 1");
-    MACE_CHECK((dilations_[0] == 1 || strides_[0] == 1) &&
-               (dilations_[1] == 1 || strides_[1] == 1),
-               "If dilations > 1, strides should be 1");
-    MACE_CHECK_NOTNULL(output_shape);
-
-    /*
-     * Convolution/pooling arithmetic:
-     *   o = (i + 2 * p - k - (k - 1) * (d - 1)) / s + 1
-     * For details, see https://arxiv.org/pdf/1603.07285.pdf or
-     * http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html
-     */
-    index_t output_height = 0, output_width = 0;
-
-    switch (padding_) {
-      case VALID:
-        output_height = (input_shape[2] - (filter_shape[2] - 1) * dilations_[0] - 1) / strides_[0] + 1;
-        output_width = (input_shape[3] - (filter_shape[3] - 1) * dilations_[1] - 1) / strides_[1] + 1;
-        break;
-      case SAME:
-        output_height = (input_shape[2] - 1) / strides_[0] + 1;
-        output_width = (input_shape[3] - 1) / strides_[1] + 1;
-        break;
-      case FULL:
-        output_height = (input_shape[2] + (filter_shape[2] - 1) * dilations_[0] - 1) / strides_[0] + 1;
-        output_width = (input_shape[3] + (filter_shape[3] - 1) * dilations_[1] - 1) / strides_[1] + 1;
-        break;
-      default:
-        MACE_CHECK(false, "Unsupported padding type: ", padding_);
-    }
-
-    output_shape[0] = input_shape[0];
-    output_shape[1] = filter_shape[0];
-    output_shape[2] = output_height;
-    output_shape[3] = output_width;
-  }
-
  protected:
   std::vector<int> strides_;
   Padding padding_;
......
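Note: the removed CalOutputSize is the standard convolution/pooling arithmetic, o = (i + 2p - k - (k - 1)(d - 1)) / s + 1, specialized per padding mode. For concreteness: a 7x7 input with a 3x3 kernel, stride 2, dilation 1 gives (7 - 3) / 2 + 1 = 3 under VALID and (7 - 1) / 2 + 1 = 4 under SAME. A standalone restatement of the per-dimension computation (the logic itself now lives on the kernels side):

    #include <cstdio>

    enum Padding { VALID, SAME, FULL };

    // Per-dimension output size, mirroring the removed switch above:
    // i = input extent, k = kernel extent, s = stride, d = dilation.
    long OutputDim(long i, long k, long s, long d, Padding p) {
      switch (p) {
        case VALID: return (i - (k - 1) * d - 1) / s + 1;
        case SAME:  return (i - 1) / s + 1;  // == ceil(i / s)
        case FULL:  return (i + (k - 1) * d - 1) / s + 1;
      }
      return 0;
    }

    int main() {
      std::printf("VALID: %ld\n", OutputDim(7, 3, 2, 1, VALID));  // 3
      std::printf("SAME:  %ld\n", OutputDim(7, 3, 2, 1, SAME));   // 4
      std::printf("FULL:  %ld\n", OutputDim(7, 3, 2, 1, FULL));   // 5
      return 0;
    }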
@@ -18,10 +18,10 @@ template<DeviceType D, typename T>
 class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
  public:
   DepthwiseConv2dOp(const OperatorDef &op_def, Workspace *ws)
-      : ConvPool2dOpBase<D, T>(op_def, ws),
-        functor_(this->Input(INPUT)->shape().data(),
-                 this->Input(FILTER)->shape().data(),
-                 this->strides_.data(), this->padding_, this->dilations_.data()) {};
+      : ConvPool2dOpBase<D, T>(op_def, ws) {
+    functor_.strides_ = this->strides_.data();
+    functor_.dilations_ = this->dilations_.data();
+  }
   bool Run() override {
     const Tensor *input = this->Input(INPUT);
@@ -38,8 +38,14 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
     filter_shape[0] *= filter_shape[1];
     filter_shape[1] = 1;
     std::vector<index_t> output_shape(4);
-    this->CalOutputSize(input->shape().data(), filter_shape.data(), output_shape.data());
+    std::vector<int> paddings(2);
+    kernels::CalcPaddingAndOutputSize(input->shape().data(),
+                                      filter_shape.data(),
+                                      this->dilations_.data(),
+                                      this->strides_.data(), this->padding_,
+                                      output_shape.data(), paddings.data());
     output->Resize(output_shape);
+    functor_.paddings_ = paddings;
     functor_(input->data<T>(), input->shape().data(), filter->data<T>(),
              filter_shape.data(), bias_data, output->mutable_data<T>(),
......
@@ -10,9 +10,10 @@ using namespace mace;
 class DepthwiseConv2dOpTest : public OpsTestBase {};
 TEST_F(DepthwiseConv2dOpTest, Simple_VALID) {
+  testing::internal::LogToStderr();
   // Construct graph
   auto& net = test_net();
-  OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2dTest")
+  OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
       .Input("Input")
       .Input("Filter")
       .Input("Bias")
@@ -35,7 +36,6 @@ TEST_F(DepthwiseConv2dOpTest, Simple_VALID) {
                               3.0f, 7.0f, 11.0f, 15.0f,
                               4.0f, 8.0f, 12.0f, 16.0f});
   net.AddInputFromArray<float>("Bias", {4}, {.1f, .2f, .3f, .4f});
-
   // Run
   net.RunOp();
@@ -61,7 +61,7 @@ TEST_F(DepthwiseConv2dOpTest, ConvNxNS12) {
   index_t multiplier = 3 + rand() % 10;
   // Construct graph
   auto& net = test_net();
-  OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2dTest")
+  OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
       .Input("Input")
       .Input("Filter")
       .Input("Bias")
......
@@ -25,7 +25,7 @@ static void DepthwiseConv2d(int iters,
   mace::testing::StopTiming();
   OpsTestNet net;
-  OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2dTest")
+  OpDefBuilder("DepthwiseConv2D", "DepthwiseConv2DTest")
       .Input("Input")
       .Input("Filter")
       .Input("Bias")
......
@@ -28,6 +28,7 @@ class PoolingOp : public ConvPool2dOpBase<D, T> {
     std::vector<index_t> output_shape(4);
     std::vector<int> paddings(2);
     std::vector<index_t> filter_shape(4);
+    // TODO(chenghui): is it kind of a hack?
     filter_shape[0] = input->shape()[1];
     filter_shape[1] = input->shape()[0];
     filter_shape[2] = kernels_[0];
......
@@ -23,6 +23,7 @@ def main(unused_args):
   with gfile.GFile(FLAGS.output, "wb") as f:
     f.write(output_graph_def.SerializeToString())
   with gfile.GFile(FLAGS.output + '_txt', "wb") as f:
+    output_graph_def.ClearField('tensors')
     f.write(str(output_graph_def))
......
 from mace.proto import mace_pb2
 import tensorflow as tf
+import numpy as np
 padding_mode = {
   'VALID': 0,
@@ -24,11 +25,20 @@ def convert_ops(unresolved_ops, net_def):
       tf_tensor = first_op.outputs[0].eval()
       tensor = net_def.tensors.add()
       tensor.name = first_op.outputs[0].name
-      tensor.dims.extend(tf_tensor.shape)
       # TODO: support other type than float
       tensor.data_type = mace_pb2.DT_FLOAT
+      shape = list(tf_tensor.shape)
+      if (first_op.name.find('pointwise_kernel') != -1 or
+          first_op.name.find('depthwise_kernel') != -1 or
+          first_op.name.endswith('weights') or
+          first_op.name.endswith('kernel')) \
+          and first_op.outputs[0].consumers()[0].type.find('Conv') != -1:
+        tf_tensor = np.transpose(tf_tensor, axes=(3, 2, 0, 1))
+        shape = [shape[3], shape[2], shape[0], shape[1]]
+      # print (tensor.name, shape)
+      tensor.dims.extend(shape)
       tensor.float_data.extend(tf_tensor.astype(float).flat)
-      # net_def.tensors.extend([tensor])
     elif first_op.type == 'Conv2D' or first_op.type == 'DepthwiseConv2dNative':
       op_def = net_def.op.add()
       op_def.name = first_op.name
@@ -43,10 +53,12 @@ def convert_ops(unresolved_ops, net_def):
       padding_arg.i = padding_mode[first_op.get_attr('padding')]
       strides_arg = op_def.arg.add()
       strides_arg.name = 'strides'
-      strides_arg.ints.extend(first_op.get_attr('strides'))
+      strides_arg.ints.extend(first_op.get_attr('strides')[2:])
       data_format_arg = op_def.arg.add()
       data_format_arg.name = 'data_format'
       data_format_arg.s = first_op.get_attr('data_format')
+      if first_op.get_attr('data_format') != 'NCHW':
+        raise Exception('only support NCHW now')
       if ops_count >= 2 and unresolved_ops[1].type == 'BiasAdd':
         bias_add_op = unresolved_ops[1]
@@ -93,7 +105,7 @@ def convert_ops(unresolved_ops, net_def):
       op_def.type = first_op.type
       op_def.input.extend([input.name for input in first_op.inputs])
       op_def.output.extend([output.name for output in first_op.outputs])
-    elif first_op.type == 'AvgPool':
+    elif first_op.type == 'AvgPool' or first_op.type == 'MaxPool':
       op_def = net_def.op.add()
       op_def.name = first_op.name
       op_def.type = 'Pooling'
@@ -107,12 +119,15 @@ def convert_ops(unresolved_ops, net_def):
       padding_arg.i = padding_mode[first_op.get_attr('padding')]
       strides_arg = op_def.arg.add()
       strides_arg.name = 'strides'
-      strides_arg.ints.extend(first_op.get_attr('strides')[1:-1])
-      strides_arg.name = 'kernels'
-      strides_arg.ints.extend(first_op.get_attr('ksize')[1:-1])
+      strides_arg.ints.extend(first_op.get_attr('strides')[2:])
+      kernels_arg = op_def.arg.add()
+      kernels_arg.name = 'kernels'
+      kernels_arg.ints.extend(first_op.get_attr('ksize')[2:])
       data_format_arg = op_def.arg.add()
       data_format_arg.name = 'data_format'
       data_format_arg.s = first_op.get_attr('data_format')
+      if first_op.get_attr('data_format') != 'NCHW':
+        raise Exception('only support NCHW now')
     else:
       raise Exception('Unknown Op: ' + first_op.name)
     pass
......
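Note on the converter hunks: filter constants feeding Conv ops are now transposed from TensorFlow's HWIO filter layout to the OIHW layout MACE's kernels expect (np.transpose with axes=(3, 2, 0, 1)); strides and ksize are sliced with [2:] rather than [1:-1] because the graph is now required to be NCHW, where dims 2 and 3 are the spatial ones; and the pooling branch fixes an earlier bug where strides_arg was reused for the 'kernels' argument, overwriting the arg's name instead of adding a second arg. A sketch of the same filter relayout written as plain loops, assuming a dense row-major HWIO source buffer:

    #include <vector>

    // HWIO -> OIHW relayout, equivalent to np.transpose(t, axes=(3, 2, 0, 1))
    // on a dense row-major buffer. H, W: kernel extents; I: input channels;
    // O: output channels.
    std::vector<float> HwioToOihw(const std::vector<float> &src,
                                  int H, int W, int I, int O) {
      std::vector<float> dst(src.size());
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w)
          for (int i = 0; i < I; ++i)
            for (int o = 0; o < O; ++o)
              dst[((o * I + i) * H + h) * W + w] =    // OIHW index
                  src[((h * W + w) * I + i) * O + o]; // HWIO index
      return dst;
    }

    int main() {
      // 2x2 kernel, 1 input channel, 1 output channel, just to exercise it.
      std::vector<float> hwio = {1, 2, 3, 4};
      std::vector<float> oihw = HwioToOihw(hwio, 2, 2, 1, 1);
      return oihw[0] == 1.0f ? 0 : 1;
    }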