提交 bdac9501 编写于 作者: C ceci3 提交者: whs

fix nas print best_tokens (#9)

上级 2a186e9c
...@@ -81,7 +81,7 @@ op_type,active_type,n_in,c_in,h_in,w_in\tlatency ...@@ -81,7 +81,7 @@ op_type,active_type,n_in,c_in,h_in,w_in\tlatency
**字段解释** **字段解释**
- **op_type(str)** - 当前op类型。 - **op_type(str)** - 当前op类型。
- **active_type (string)** - 激活函数类型,包含:relu, prelu, sigmoid, relu6, tanh。 - **active_type (string|None)** - 激活函数类型,包含:relu, prelu, sigmoid, relu6, tanh。
- **n_in (int)** - 输入 Tensor 的批尺寸 (batch size)。 - **n_in (int)** - 输入 Tensor 的批尺寸 (batch size)。
- **c_in (int)** - 输入 Tensor 的通道 (channel) 数。 - **c_in (int)** - 输入 Tensor 的通道 (channel) 数。
- **h_in (int)** - 输入 Tensor 的特征高度。 - **h_in (int)** - 输入 Tensor 的特征高度。
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from paddle.fluid import Program
from ..core import GraphWrapper, OpWrapper
__all__ = ["LatencyEvaluator", "TableLatencyEvaluator"] __all__ = ["LatencyEvaluator", "TableLatencyEvaluator"]
...@@ -28,33 +30,33 @@ class LatencyEvaluator(object): ...@@ -28,33 +30,33 @@ class LatencyEvaluator(object):
i = 0 i = 0
for op in graph.ops(): for op in graph.ops():
if op.type() in ['conv2d', 'depthwise_conv2d']: if op.type() in ['conv2d', 'depthwise_conv2d']:
tmp = _conv_op_args(op) tmp = self._conv_op_args(op)
elif op.type() in [ elif op.type() in [
'elementwise_add', 'elementwise_mul', 'elementwise_max' 'elementwise_add', 'elementwise_mul', 'elementwise_max'
]: ]:
tmp = _eltwise_op_args(op) tmp = self._eltwise_op_args(op)
elif op.type() in [ elif op.type() in [
'relu', 'prelu', 'sigmoid', 'relu6', 'elu', 'brelu', 'relu', 'prelu', 'sigmoid', 'relu6', 'elu', 'brelu',
'leaky_relu' 'leaky_relu'
]: ]:
tmp = _activation_op_args(op) tmp = self._activation_op_args(op)
elif op.type() == 'batch_norm': elif op.type() == 'batch_norm':
tmp = _batch_norm_op_args(op) tmp = self._batch_norm_op_args(op)
elif op.type() == 'pool2d': elif op.type() == 'pool2d':
tmp = _pooling_op_args(op) tmp = self._pooling_op_args(op)
elif op.type() == 'batch_norm': elif op.type() == 'batch_norm':
tmp = _batch_norm_op_args(op) tmp = self._batch_norm_op_args(op)
elif op.type() == 'softmax': elif op.type() == 'softmax':
tmp = _softmax_op_args(op) tmp = self._softmax_op_args(op)
elif op.type() == 'mul': elif op.type() == 'mul':
tmp = _fc_op_args(op) tmp = self._fc_op_args(op)
else: else:
tmp = None tmp = None
if tmp: if tmp:
ops.append(tmp) ops.append(tmp)
return ops return ops
def _conv_op_args(op): def _conv_op_args(self, op):
assert isinstance(op, OpWrapper) assert isinstance(op, OpWrapper)
tmp, res = [], [] tmp, res = [], []
# op_name # op_name
...@@ -69,11 +71,11 @@ class LatencyEvaluator(object): ...@@ -69,11 +71,11 @@ class LatencyEvaluator(object):
# batch size # batch size
tmp.append(1) tmp.append(1)
# channels, height, width # channels, height, width
in_shapes = op.inputs('Input')[0].shape in_shapes = op.inputs('Input')[0].shape()
tmp = tmp + [int(in_shapes[1]), int(in_shapes[2]), int(in_shapes[3])] tmp = tmp + [int(in_shapes[1]), int(in_shapes[2]), int(in_shapes[3])]
# output channels # output channels
w_shapes = op.inputs('Filter')[0].shape w_shapes = op.inputs('Filter')[0].shape()
tmp.append(int(w_shapes[0])) tmp.append(int(w_shapes[0]))
# group # group
...@@ -104,7 +106,7 @@ class LatencyEvaluator(object): ...@@ -104,7 +106,7 @@ class LatencyEvaluator(object):
tmp = tmp + res tmp = tmp + res
return tmp return tmp
def _batch_norm_op_args(op): def _batch_norm_op_args(self, op):
tmp = [] tmp = []
# op name # op name
tmp.append('batch_norm') tmp.append('batch_norm')
...@@ -116,11 +118,11 @@ class LatencyEvaluator(object): ...@@ -116,11 +118,11 @@ class LatencyEvaluator(object):
# batch size # batch size
tmp.append(1) tmp.append(1)
# input channels, height, width # input channels, height, width
in_shapes = op.inputs("X")[0].shape in_shapes = op.inputs("X")[0].shape()
tmp = tmp + [int(in_shapes[1]), int(in_shapes[2]), int(in_shapes[3])] tmp = tmp + [int(in_shapes[1]), int(in_shapes[2]), int(in_shapes[3])]
return tmp return tmp
def _eltwise_op_args(op): def _eltwise_op_args(self, op):
# op name # op name
tmp = ['eltwise'] tmp = ['eltwise']
# elementwise type, TODO: add more ops # elementwise type, TODO: add more ops
...@@ -133,7 +135,7 @@ class LatencyEvaluator(object): ...@@ -133,7 +135,7 @@ class LatencyEvaluator(object):
# batch size # batch size
tmp.append(1) tmp.append(1)
# input channels, height, width # input channels, height, width
in_shapes = op.inputs('X')[0].shape in_shapes = op.inputs('X')[0].shape()
while len(in_shapes) < 4: while len(in_shapes) < 4:
in_shapes = in_shapes + (1, ) in_shapes = in_shapes + (1, )
...@@ -141,14 +143,14 @@ class LatencyEvaluator(object): ...@@ -141,14 +143,14 @@ class LatencyEvaluator(object):
tmp.append(int(in_shapes[i])) tmp.append(int(in_shapes[i]))
return tmp return tmp
def _activation_op_args(op): def _activation_op_args(self, op):
tmp = [] tmp = []
# activation type # activation type
tmp.append(op.type()) tmp.append(op.type())
# batch size # batch size
tmp.append(1) tmp.append(1)
# input channels, height, width # input channels, height, width
in_shapes = op.inputs('X')[0].shape in_shapes = op.inputs('X')[0].shape()
while len(in_shapes) < 4: while len(in_shapes) < 4:
in_shapes = in_shapes + (1, ) in_shapes = in_shapes + (1, )
...@@ -156,7 +158,7 @@ class LatencyEvaluator(object): ...@@ -156,7 +158,7 @@ class LatencyEvaluator(object):
tmp.append(int(in_shapes[i])) tmp.append(int(in_shapes[i]))
return tmp return tmp
def _pooling_op_args(op): def _pooling_op_args(self, op):
tmp, res = [], [] tmp, res = [], []
# op name # op name
tmp.append('pooling') tmp.append('pooling')
...@@ -165,7 +167,7 @@ class LatencyEvaluator(object): ...@@ -165,7 +167,7 @@ class LatencyEvaluator(object):
# batch size # batch size
tmp.append(1) tmp.append(1)
# channels, height, width # channels, height, width
in_shapes = op.inputs('X')[0].shape in_shapes = op.inputs('X')[0].shape()
tmp = tmp + [int(in_shapes[1]), int(in_shapes[2]), int(in_shapes[3])] tmp = tmp + [int(in_shapes[1]), int(in_shapes[2]), int(in_shapes[3])]
# kernel size # kernel size
ksize = op.attr('ksize') ksize = op.attr('ksize')
...@@ -201,7 +203,7 @@ class LatencyEvaluator(object): ...@@ -201,7 +203,7 @@ class LatencyEvaluator(object):
tmp = tmp + res tmp = tmp + res
return tmp return tmp
def _softmax_op_args(op): def _softmax_op_args(self, op):
# op name # op name
tmp = ['softmax'] tmp = ['softmax']
# axis # axis
...@@ -209,7 +211,7 @@ class LatencyEvaluator(object): ...@@ -209,7 +211,7 @@ class LatencyEvaluator(object):
# batch size # batch size
tmp.append(1) tmp.append(1)
# input channels, height, width # input channels, height, width
in_shapes = op.inputs('X')[0].shape in_shapes = op.inputs('X')[0].shape()
while len(in_shapes) < 4: while len(in_shapes) < 4:
in_shapes = in_shapes + (1, ) in_shapes = in_shapes + (1, )
...@@ -218,7 +220,7 @@ class LatencyEvaluator(object): ...@@ -218,7 +220,7 @@ class LatencyEvaluator(object):
return tmp return tmp
def _fc_op_args(blocks, op): def _fc_op_args(self, op):
# op name # op name
tmp = ['conv'] tmp = ['conv']
# flag bias # flag bias
...@@ -229,12 +231,12 @@ class LatencyEvaluator(object): ...@@ -229,12 +231,12 @@ class LatencyEvaluator(object):
tmp.append(1) tmp.append(1)
# input channels, height, width # input channels, height, width
channels = 1 channels = 1
in_shape = op.inputs('X')[0].shape in_shape = op.inputs('X')[0].shape()
for i in range(1, len(in_shape)): for i in range(1, len(in_shape)):
channels *= in_shape[i] channels *= in_shape[i]
tmp = tmp + [int(channels), 1, 1] tmp = tmp + [int(channels), 1, 1]
# output channels # output channels
tmp.append(int(op.outputs('Out')[0].shape[1])) tmp.append(int(op.outputs('Out')[0].shape()[1]))
# groups, kernel size, padding, stride, dilation # groups, kernel size, padding, stride, dilation
tmp = tmp + [1, 1, 0, 1, 1] tmp = tmp + [1, 1, 0, 1, 1]
return tmp return tmp
...@@ -279,5 +281,6 @@ class TableLatencyEvaluator(LatencyEvaluator): ...@@ -279,5 +281,6 @@ class TableLatencyEvaluator(LatencyEvaluator):
graph = GraphWrapper(graph) graph = GraphWrapper(graph)
assert isinstance(graph, GraphWrapper) assert isinstance(graph, GraphWrapper)
for op in self._get_ops_from_graph(graph): for op in self._get_ops_from_graph(graph):
total_latency += self._op_latency(self._delimiter.join(op)) total_latency += self._op_latency(
self._delimiter.join(map(lambda x: str(x), op)))
return total_latency return total_latency
...@@ -98,7 +98,8 @@ class SAController(EvolutionaryController): ...@@ -98,7 +98,8 @@ class SAController(EvolutionaryController):
self._best_tokens = tokens self._best_tokens = tokens
_logger.info( _logger.info(
"Controller - iter: {}; best_reward: {}, best tokens: {}, current_reward: {}; current tokens: {}". "Controller - iter: {}; best_reward: {}, best tokens: {}, current_reward: {}; current tokens: {}".
format(self._iter, self._reward, self._tokens, reward, tokens)) format(self._iter, self._max_reward, self._best_tokens, reward,
tokens))
if self._checkpoints != None: if self._checkpoints != None:
self._save_checkpoint(self._checkpoints) self._save_checkpoint(self._checkpoints)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册