diff --git a/doc/graph_survey.md b/doc/design/graph_survey.md similarity index 68% rename from doc/graph_survey.md rename to doc/design/graph_survey.md index eec4ddb692515b55404b14eedd89802fe5577de3..6fca254495bac2c89bf54bed4c0e9dd0eed98747 100644 --- a/doc/graph_survey.md +++ b/doc/design/graph_survey.md @@ -15,7 +15,7 @@ ```python def get_symbol(num_classes=10, **kwargs): data = mx.symbol.Variable('data') - data = mx.sym.Flatten(data=data) + data = mx.symbol.Flatten(data=data) fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128) act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu") fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64) @@ -119,3 +119,113 @@ Expression的主要成员为ComputationGraph,可以在用户配置网络的过 - 在用户配置网络时,所有的返回值都是Expression,包括最初的输入数据,及参数等 - Expression已经包含了所有的依赖关系,可以被当做执行的target + +下面我们来看几个实例: + +- Mxnet + + +``` +>>> import mxnet as mx +>>> data = mx.symbol.Variable('data') +>>> print data.debug_str() +Variable:data + +>>> data = mx.symbol.Flatten(data=data) +>>> print data.debug_str() +Symbol Outputs: + output[0]=flatten0(0) +Variable:data +-------------------- +Op:Flatten, Name=flatten0 +Inputs: + arg[0]=data(0) version=0 + +>>> fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128) +>>> print fc1.debug_str() +Symbol Outputs: + output[0]=fc1(0) +Variable:data +-------------------- +Op:Flatten, Name=flatten0 +Inputs: + arg[0]=data(0) version=0 +Variable:fc1_weight +Variable:fc1_bias +-------------------- +Op:FullyConnected, Name=fc1 +Inputs: + arg[0]=flatten0(0) + arg[1]=fc1_weight(0) version=0 + arg[2]=fc1_bias(0) version=0 +Attrs: + num_hidden=128 + +``` + +- TensorFlow + +``` +>>> import tensorflow as tf +>>> c = tf.constant([[1.0, 2.0], [3.0, 4.0]]) +>>> print c.graph + +>>> d = tf.constant([[1.0, 1.0], [0.0, 1.0]]) +>>> print d.graph + +>>> e = tf.matmul(c, d) +>>> print e.graph + +``` + +没有找到Graph的debug string接口,但是可以明确知道配置过程中只存在一个Graph。 + + +- dynet + +dynet可以在C++中书写配置 + +``` +ComputationGraph cg; +Expression W = parameter(cg, pW); +cg.print_graphviz(); + +Expression pred = W * xs[i]; +cg.print_graphviz(); + +Expression loss = square(pred - ys[i]); +cg.print_graphviz(); +``` + +编译运行后,得到打印结果: + +``` +# first print +digraph G { + rankdir=LR; + nodesep=.05; + N0 [label="v0 = parameters({1}) @ 0x7ffe4de00110"]; +} +# second print +digraph G { + rankdir=LR; + nodesep=.05; + N0 [label="v0 = parameters({1}) @ 0x7ffe4de00110"]; + N1 [label="v1 = v0 * -0.98"]; + N0 -> N1; +} +# third print +digraph G { + rankdir=LR; + nodesep=.05; + N0 [label="v0 = parameters({1}) @ 0x7ffe4de00110"]; + N1 [label="v1 = v0 * -0.98"]; + N0 -> N1; + N2 [label="v2 = -1.88387 - v1"]; + N1 -> N2; + N3 [label="v3 = -v2"]; + N2 -> N3; + N4 [label="v4 = square(v3)"]; + N3 -> N4; +} +```