提交 6f99bf84 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Store ksize attribute for graph transfer to SOC

Change: 137855838
上级 b00e9404
......@@ -31,6 +31,7 @@ static constexpr const char* const CONST_SHAPE_PREFIX = "const_shape_";
// Prefix prepended to a node's padding string when it is emitted as a
// transfer parameter (e.g. "SAME" -> "NN_PAD_SAME"; see the tests below).
static constexpr const char* const PADDING_PREFIX = "NN_PAD_";
// NodeDef attribute names read from the graph during node registration.
static constexpr const char* const PADDING_ATTR_NAME = "padding";
static constexpr const char* const STRIDES_ATTR_NAME = "strides";
// Kernel-size attribute; present on pooling ops (e.g. MaxPool) but not on
// Conv2D, so registration checks for it before reading it.
static constexpr const char* const KSIZE_ATTR_NAME = "ksize";
// The two padding-scheme values a node's "padding" attribute may carry.
static constexpr const char* const PADDING_VALID_STR = "VALID";
static constexpr const char* const PADDING_SAME_STR = "SAME";
......@@ -192,7 +193,13 @@ void GraphTransferer::RegisterNodeWithPaddingAndStrides(
std::vector<int32> strides;
context->GetAttr(STRIDES_ATTR_NAME, &strides);
const int stride_id = RegisterConstantShape(strides);
std::vector<int> extra_inputs{stride_id, 0};
std::vector<int> extra_inputs{stride_id};
if (node.def().attr().count(KSIZE_ATTR_NAME) > 0) {
std::vector<int32> kernel_sizes;
context->GetAttr(KSIZE_ATTR_NAME, &kernel_sizes);
const int ksize_id = RegisterConstantShape(kernel_sizes);
extra_inputs.push_back(ksize_id);
}
AppendNodeParams(node.name(), id, node.type_string(), padding,
node.num_inputs(), extra_inputs, node.num_outputs());
}
......
......@@ -58,14 +58,33 @@ static GraphDef CreateConvGraphDef() {
test::FillIota<float>(&input_data, 1.0f);
ops::Output input =
ops::Const(root.WithOpName("input"), ops::Input::Initializer(input_data));
const int stride = 1;
Tensor filter_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
test::FillIota<float>(&filter_data, 1.0f);
ops::Output filter = ops::Const(root.WithOpName("filter"),
ops::Input::Initializer(filter_data));
const std::vector<int> strides{1, 1, 1, 1};
ops::Output conv =
ops::Conv2D(root.WithOpName("conv"), input, filter, strides, "SAME");
GraphDef def;
TF_CHECK_OK(root.ToGraphDef(&def));
return def;
}
// Builds a minimal graph containing a single MaxPool node (plus its const
// inputs) for exercising ksize registration in GraphTransferer.
//
// Fix: the scraped diff had lost its +/- markers, leaving two stale
// pre-change lines inside this function (`ops::Output conv = ops::Conv2D(...,
// {1, stride, stride, 1}, "SAME");`) that referenced `stride` — a variable
// removed by this commit — so the block could not compile. Those residue
// lines are dropped and the missing closing brace restored.
static GraphDef CreatePoolGraphDef() {
  Scope root = Scope::NewRootScope();
  Tensor input_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
  test::FillIota<float>(&input_data, 1.0f);
  ops::Output input =
      ops::Const(root.WithOpName("input"), ops::Input::Initializer(input_data));
  // NOTE(review): `filter` and `padding` are not consumed by MaxPool below —
  // confirm against the test's const-node-count expectations whether they are
  // intentionally kept in the graph.
  Tensor filter_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
  test::FillIota<float>(&filter_data, 1.0f);
  ops::Output filter = ops::Const(root.WithOpName("filter"),
                                  ops::Input::Initializer(filter_data));
  const std::vector<int> ksize{1, 1, 1, 1};
  const std::vector<int> padding{0, 0, 0, 0};
  const std::vector<int> strides{1, 1, 1, 1};
  ops::Output max_pool =
      ops::MaxPool(root.WithOpName("maxpool"), input, ksize, strides, "SAME");
  GraphDef def;
  TF_CHECK_OK(root.ToGraphDef(&def));
  return def;
}
......@@ -139,9 +158,29 @@ TEST_F(GraphTransfererTest, LoadConvGraph) {
const int id = params_conv->id;
EXPECT_TRUE(id > 0 && id <= (const_node_count + op_node_count));
EXPECT_EQ("Conv2D", params_conv->type);
EXPECT_EQ(4, params_conv->inputs_size);
EXPECT_EQ(3, params_conv->inputs_size);
EXPECT_EQ(1, params_conv->outputs_size);
EXPECT_EQ("NN_PAD_SAME", params_conv->padding);
}
// Verifies that a MaxPool graph is transferred with its ksize attribute
// registered as an extra const (shape) input: the transferer should report
// 3 const nodes, 1 op node, and 3 inputs on the MaxPool op itself.
TEST_F(GraphTransfererTest, LoadMaxPoolGraph) {
  GraphDef graph_def = CreatePoolGraphDef();
  _session->Create(graph_def);

  GraphTransferer transferer;
  transferer.LoadGraphFromProto(graph_def);

  const int num_const_nodes = transferer.GetConstNodeParams().size();
  const int num_op_nodes = transferer.GetOpNodeParams().size();
  ASSERT_EQ(3, num_const_nodes);
  ASSERT_EQ(1, num_op_nodes);

  const GraphTransferer::NodeTransferParams* pool_params =
      FindOpNodeParams(transferer, "maxpool");
  ASSERT_NE(pool_params, nullptr);

  // The op's id must fall within the range of all registered node ids.
  const int node_id = pool_params->id;
  EXPECT_GT(node_id, 0);
  EXPECT_LE(node_id, num_const_nodes + num_op_nodes);
  EXPECT_EQ("MaxPool", pool_params->type);
  EXPECT_EQ(3, pool_params->inputs_size);
  EXPECT_EQ(1, pool_params->outputs_size);
  EXPECT_EQ("NN_PAD_SAME", pool_params->padding);
}
} // namespace tensorflow
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册