Unverified commit 33f13067, authored by Leo Chen, committed by GitHub

update layers used in mnist dygraph model, test=develop (#21947)

* update layers used in mnist dygraph model, test=develop

* fix import issue, test=develop

* add dygraph utils, test=develop

* add unittest, test=develop
Parent 64baee41
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import core
from ..framework import dygraph_only
@dygraph_only
def _append_activation_in_dygraph(input,
act=None,
use_cudnn=False,
use_mkldnn=False):
"""Append activation in dygraph mode.
Args:
input: the input variable.
act: activation type
use_mkldnn: if use mkldnn
use_cudnn: if use cudnn
Return the Variable after append activation
"""
if not act:
return input
attrs = {'use_cudnn': use_cudnn, 'use_mkldnn': use_mkldnn}
inputs = {"X": [input]}
act_op = getattr(core.ops, act)
res = act_op(inputs, attrs)
return res['Out'][0]
@dygraph_only
def _append_bias_in_dygraph(
input,
bias=None,
axis=1, ):
"""Append bias operation in dygraph mode.
Args:
input: the input variable.
bias: the bias to be appended
axis: the axis to perform operation
Return the Variable after bias operation
"""
if not bias:
return input
attrs = {'axis': axis}
inputs = {'X': [input], 'Y': [bias]}
outs = core.ops.elementwise_add(inputs, attrs)
return outs['Out'][0]
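A minimal usage sketch of the two helpers above, assuming an active dygraph guard and that the new module is importable as paddle.fluid.dygraph.dygraph_utils; the activation name 'relu' is only illustrative and must be registered under core.ops:

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import dygraph_utils

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.random.rand(2, 3).astype('float32'))
    # act=None is a no-op; otherwise the call dispatches to core.ops.relu.
    y = dygraph_utils._append_activation_in_dygraph(x, act='relu')
    # The bias append goes through core.ops.elementwise_add along axis 1.
    b = fluid.dygraph.to_variable(np.zeros(3).astype('float32'))
    z = dygraph_utils._append_bias_in_dygraph(x, bias=b, axis=1)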
@@ -15,9 +15,9 @@
from __future__ import print_function
from six.moves import reduce
from .. import core
from ..layers import utils
from ..dygraph import dygraph_utils
from . import layers
from ..framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter, _dygraph_tracer_
from ..param_attr import ParamAttr
@@ -235,6 +235,29 @@ class Conv2D(layers.Layer):
self._bias_param = value
def forward(self, input):
inputs = {
'Input': [input],
'Filter': [self._filter_param],
}
attrs = {
'strides': self._stride,
'paddings': self._padding,
'dilations': self._dilation,
'groups': self._groups if self._groups else 1,
'use_cudnn': self._use_cudnn,
'use_mkldnn': False,
}
if in_dygraph_mode():
outs = core.ops.conv2d(inputs, attrs)
pre_bias = outs['Output'][0]
pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias,
self._bias_param, 1)
return dygraph_utils._append_activation_in_dygraph(pre_act,
self._act)
pre_bias = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
@@ -245,14 +268,7 @@ class Conv2D(layers.Layer):
'Filter': self._filter_param,
},
outputs={"Output": pre_bias},
attrs={
'strides': self._stride,
'paddings': self._padding,
'dilations': self._dilation,
'groups': self._groups if self._groups else 1,
'use_cudnn': self._use_cudnn,
'use_mkldnn': False,
})
attrs=attrs)
if self._bias_param is not None:
pre_act = self._helper.create_variable_for_type_inference(
@@ -858,23 +874,30 @@ class Pool2D(layers.Layer):
self._l_type = 'pool2d'
def forward(self, input):
attrs = {
"pooling_type": self._pool_type,
"ksize": self._pool_size,
"global_pooling": self._global_pooling,
"strides": self._pool_stride,
"paddings": self._pool_padding,
"use_cudnn": self._use_cudnn,
"ceil_mode": self._ceil_mode,
"use_mkldnn": False,
"exclusive": self._exclusive,
}
inputs = {"X": [input]}
if in_dygraph_mode():
outs = core.ops.pool2d(inputs, attrs)
return outs['Out'][0]
pool_out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type=self._l_type,
inputs={"X": input},
outputs={"Out": pool_out},
attrs={
"pooling_type": self._pool_type,
"ksize": self._pool_size,
"global_pooling": self._global_pooling,
"strides": self._pool_stride,
"paddings": self._pool_padding,
"use_cudnn": self._use_cudnn,
"ceil_mode": self._ceil_mode,
"use_mkldnn": False,
"exclusive": self._exclusive,
})
attrs=attrs)
return pool_out
@@ -948,17 +971,26 @@ class Linear(layers.Layer):
shape=[output_dim], attr=bias_attr, dtype=dtype, is_bias=True)
def forward(self, input):
attrs = {
"transpose_X": False,
"transpose_Y": False,
"alpha": 1,
}
inputs = {"X": [input], "Y": [self.weight]}
if in_dygraph_mode():
outs = core.ops.matmul(inputs, attrs)
pre_bias = outs['Out'][0]
pre_act = dygraph_utils._append_bias_in_dygraph(
pre_bias, self.bias, axis=len(input.shape) - 1)
return dygraph_utils._append_activation_in_dygraph(pre_act,
self._act)
tmp = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="matmul",
inputs={"X": input,
"Y": self.weight},
outputs={"Out": tmp},
attrs={
"transpose_X": False,
"transpose_Y": False,
"alpha": 1,
})
type="matmul", inputs=inputs, outputs={"Out": tmp}, attrs=attrs)
if self.bias:
pre_activation = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
@@ -1548,6 +1580,7 @@ class Embedding(layers.Layer):
'remote_prefetch': self._remote_prefetch,
'padding_idx': self._padding_idx
}
if in_dygraph_mode():
inputs = {'Ids': [input], 'W': [self._w]}
outs = core.ops.lookup_table_v2(inputs, attrs)
......
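All four forward methods now share one pattern: build the inputs/attrs dicts once, and in dygraph mode call the matching core.ops entry directly instead of appending an op through LayerHelper. A hedged end-to-end sketch using the touched layers (shapes and hyperparameters below are illustrative, not taken from this PR):

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D, Pool2D, Linear

with fluid.dygraph.guard():
    img = fluid.dygraph.to_variable(
        np.random.rand(4, 1, 28, 28).astype('float32'))
    conv = Conv2D(num_channels=1, num_filters=20, filter_size=5, act='relu')
    pool = Pool2D(pool_size=2, pool_type='max', pool_stride=2)
    fc = Linear(input_dim=20 * 12 * 12, output_dim=10, act='softmax')
    x = pool(conv(img))                   # core.ops.conv2d, core.ops.pool2d
    x = fluid.layers.reshape(x, [4, -1])  # core.ops.reshape2 (see nn.py below)
    out = fc(x)                           # core.ops.matmul + bias + activation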
@@ -234,36 +234,46 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
predict = fluid.layers.fc(input=x, size=class_num, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)
"""
if not soft_label:
return cross_entropy2(input, label, ignore_index)
inputs = {'X': [input], 'Label': [label]}
attrs = {"soft_label": soft_label, "ignore_index": ignore_index}
if in_dygraph_mode():
outs = core.ops.cross_entropy(inputs, attrs)
return outs['Y'][0]
check_type_and_dtype(input, 'input', Variable,
['float16', 'float32', 'float64'], 'cross_entropy')
if not soft_label:
return cross_entropy2(input, label, ignore_index)
helper = LayerHelper('cross_entropy', **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='cross_entropy',
inputs={'X': [input],
'Label': [label]},
outputs={'Y': [out]},
attrs={"soft_label": soft_label,
"ignore_index": ignore_index})
type='cross_entropy', inputs=inputs, outputs={'Y': [out]}, attrs=attrs)
return out
def cross_entropy2(input, label, ignore_index=kIgnoreIndex):
inputs = {'X': [input], 'Label': [label]}
attrs = {'ignore_index': ignore_index}
if in_dygraph_mode():
outs = core.ops.cross_entropy2(inputs, attrs)
return outs['Y'][0]
check_type_and_dtype(input, 'input', Variable,
['float16', 'float32', 'float64'], 'cross_entropy2')
helper = LayerHelper('cross_entropy2', **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype)
xshape = helper.create_variable_for_type_inference(dtype=input.dtype)
match_x = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='cross_entropy2',
inputs={'X': [input],
'Label': [label]},
inputs=inputs,
outputs={'Y': [out],
'MatchX': [match_x],
'XShape': [xshape]},
attrs={'ignore_index': ignore_index})
attrs=attrs)
return out
......
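A minimal sketch of the new fast path (hypothetical data): with soft_label=False the call is forwarded to cross_entropy2, which in dygraph mode invokes core.ops.cross_entropy2 directly.

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    probs = fluid.layers.softmax(fluid.dygraph.to_variable(
        np.random.rand(3, 10).astype('float32')))
    label = fluid.dygraph.to_variable(np.array([[1], [0], [9]], dtype='int64'))
    # Hard labels: routed through cross_entropy2's dygraph branch.
    loss = fluid.layers.cross_entropy(input=probs, label=label)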
@@ -20,7 +20,8 @@ from __future__ import print_function
import warnings
from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant
from ..framework import Variable
from ..framework import Variable, in_dygraph_mode, _varbase_creator
from .. import core
from ..param_attr import ParamAttr
from . import nn
from ..data_feeder import check_type_and_dtype
@@ -71,6 +72,26 @@ def accuracy(input, label, k=1, correct=None, total=None):
#[array([0.6666667], dtype=float32)]
"""
if in_dygraph_mode():
topk_out, topk_indices = nn.topk(input, k=k)
inputs = {
"Out": [topk_out],
"Indices": [topk_indices],
"Label": [label]
}
acc_out = _varbase_creator(dtype="float32")
if correct is None:
correct = _varbase_creator(dtype="int64")
if total is None:
total = _varbase_creator(dtype="int64")
outputs = {
"Accuracy": [acc_out],
"Correct": [correct],
"Total": [total]
}
outs = core.ops.accuracy(inputs, {}, outputs)
return outs['Accuracy'][0]
helper = LayerHelper("accuracy", **locals())
check_type_and_dtype(input, 'input', Variable,
['float16', 'float32', 'float64'], 'accuracy')
......
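Unlike the other branches, accuracy pre-creates its output VarBases with _varbase_creator and hands core.ops.accuracy an explicit outputs dict. A short usage sketch with hypothetical tensors (the unit test added at the end of this commit exercises the same path):

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    pred = fluid.dygraph.to_variable(np.random.rand(3, 10).astype('float32'))
    label = fluid.dygraph.to_variable(np.array([[1], [0], [9]], dtype='int64'))
    # Runs nn.topk first, then the core.ops.accuracy fast path above.
    acc = fluid.layers.accuracy(input=pred, label=label, k=5)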
@@ -26,6 +26,7 @@ from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant, NumpyArrayInitializer
from ..framework import Variable, OpProtoHolder, in_dygraph_mode, dygraph_only, _dygraph_tracer, default_main_program
from ..dygraph import base
from ..dygraph import dygraph_utils
from ..param_attr import ParamAttr
from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_
from .tensor import concat, assign, fill_constant, zeros
@@ -186,28 +187,6 @@ __all__ = [
]
@dygraph_only
def _append_activation_in_dygraph(input,
act=None,
use_cudnn=False,
use_mkldnn=False):
"""Append activation in dygraph mode.
Args:
input: the input variable.
act: activation type
use_mkldnn: if use mkldnn
use_cudnn: if use cudnn
Return the Variable after append activation
"""
attrs = {'use_cudnn': use_cudnn, 'use_mkldnn': use_mkldnn}
inputs = {"X": [input]}
act_op = getattr(core.ops, act)
res = act_op(inputs, attrs)
return res['Out'][0]
@dygraph_only
def _elementwise_op_in_dygraph(x,
y,
@@ -219,13 +198,10 @@ def _elementwise_op_in_dygraph(x,
inputs = {'X': [x], 'Y': [y]}
op = getattr(core.ops, op_name)
outs = op(inputs, attrs)
pre_act = outs['Out'][0]
out = outs['Out'][0]
if not act:
return pre_act
else:
return _append_activation_in_dygraph(
pre_act, act, use_mkldnn=use_mkldnn)
return dygraph_utils._append_activation_in_dygraph(
out, act, use_mkldnn=use_mkldnn)
def fc(input,
@@ -4736,15 +4712,23 @@ def topk(input, k, name=None):
vk_values, vk_indices = layers.topk(input2, k=vk) #vk_values.shape=[None, 13, k], vk_indices.shape=[None, 13, k]
"""
helper = LayerHelper("top_k", **locals())
values = helper.create_variable_for_type_inference(dtype=input.dtype)
indices = helper.create_variable_for_type_inference(dtype="int64")
inputs = {"X": [input]}
attrs = None
attrs = {}
if isinstance(k, Variable):
inputs['K'] = k
inputs['K'] = [k]
else:
attrs = {'k': k}
if in_dygraph_mode():
outs = core.ops.top_k(inputs, attrs)
outs['Out'][0].stop_gradient = True
outs['Indices'][0].stop_gradient = True
return outs['Out'][0], outs['Indices'][0]
helper = LayerHelper("top_k", **locals())
values = helper.create_variable_for_type_inference(dtype=input.dtype)
indices = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(
type="top_k",
inputs=inputs,
@@ -5594,11 +5578,8 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
inputs = {'X': [x]}
outs = core.ops.reshape2(inputs, attrs)
pre_act = outs['Out'][0]
if act is None:
return pre_act
else:
return _append_activation_in_dygraph(pre_act, act)
out = outs['Out'][0]
return dygraph_utils._append_activation_in_dygraph(out, act)
check_type_and_dtype(x, 'x', Variable,
['float16', 'float32', 'float64', 'int32', 'int64'],
@@ -11333,6 +11314,10 @@ def mean(x, name=None):
name='data', shape=[2, 3], dtype='float32')
mean = fluid.layers.mean(input)
"""
if in_dygraph_mode():
inputs = {"X": [x]}
outs = core.ops.mean(inputs)
return outs['Out'][0]
helper = LayerHelper("mean", **locals())
check_type_and_dtype(x, 'x', Variable, ['float16', 'float32', 'float64'],
......
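topk, reshape, and mean follow the same recipe: hoist the inputs/attrs construction above the branch, call core.ops.<op> in dygraph mode, and fall through to LayerHelper otherwise. A hedged sketch over a hypothetical tensor:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.random.rand(4, 8).astype('float32'))
    values, indices = fluid.layers.topk(x, k=3)  # outputs stop gradients
    y = fluid.layers.reshape(x, shape=[2, 16], act='relu')  # fused activation
    avg = fluid.layers.mean(y)                   # core.ops.mean takes no attrs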
@@ -1712,20 +1712,20 @@ class AdamOptimizer(Optimizer):
# create the adam optimize op
inputs = {
"Param": param_and_grad[0],
"Grad": param_and_grad[1],
"LearningRate": self._create_param_lr(param_and_grad),
"Moment1": moment1,
"Moment2": moment2,
"Beta1Pow": beta1_pow_acc,
"Beta2Pow": beta2_pow_acc
"Param": [param_and_grad[0]],
"Grad": [param_and_grad[1]],
"LearningRate": [self._create_param_lr(param_and_grad)],
"Moment1": [moment1],
"Moment2": [moment2],
"Beta1Pow": [beta1_pow_acc],
"Beta2Pow": [beta2_pow_acc]
}
outputs = {
"ParamOut": param_and_grad[0],
"Moment1Out": moment1,
"Moment2Out": moment2,
"Beta1PowOut": beta1_pow_acc,
"Beta2PowOut": beta2_pow_acc,
"ParamOut": [param_and_grad[0]],
"Moment1Out": [moment1],
"Moment2Out": [moment2],
"Beta1PowOut": [beta1_pow_acc],
"Beta2PowOut": [beta2_pow_acc],
}
attrs = {
"epsilon": self._epsilon,
@@ -1742,6 +1742,10 @@ class AdamOptimizer(Optimizer):
else:
attrs['beta2'] = self._beta2
if framework.in_dygraph_mode():
core.ops.adam(inputs, attrs, outputs)
return None
adam_op = block.append_op(
type=self.type,
inputs=inputs,
......
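Wrapping each Variable in a one-element list matches the calling convention of core.ops.adam while staying valid for block.append_op in static mode. A hedged training-step sketch (hypothetical model; assumes a develop build where dygraph optimizers accept parameter_list):

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear

with fluid.dygraph.guard():
    fc = Linear(input_dim=3, output_dim=1)
    adam = fluid.optimizer.AdamOptimizer(
        learning_rate=0.001, parameter_list=fc.parameters())
    x = fluid.dygraph.to_variable(np.random.rand(2, 3).astype('float32'))
    loss = fluid.layers.mean(fc(x))
    loss.backward()
    adam.minimize(loss)  # hits the core.ops.adam branch above
    fc.clear_gradients()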
@@ -284,7 +284,7 @@ class TestDygraphResnet(unittest.TestCase):
if traced_layer is not None:
resnet.eval()
traced_layer._switch(is_test=True)
out_dygraph = resnet([img])
out_dygraph = resnet(img)
out_static = traced_layer([img])
traced_layer._switch(is_test=False)
helper.assertEachVar(out_dygraph, out_static)
......
@@ -1627,6 +1627,34 @@ class TestLayer(LayerTest):
self.assertIsNotNone(out2)
self.assertIsNotNone(out3)
def test_accuracy(self):
x = np.random.rand(3, 32, 32).astype("float32")
y = np.array([[1], [0], [1]])
with self.static_graph():
data = fluid.data(name="input", shape=[-1, 32, 32], dtype="float32")
label = fluid.data(name="label", shape=[-1, 1], dtype="int")
fc_out = fluid.layers.fc(input=data, size=10)
predict = fluid.layers.softmax(input=fc_out)
result = fluid.layers.accuracy(input=predict, label=label, k=5)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
x = np.random.rand(3, 32, 32).astype("float32")
y = np.array([[1], [0], [1]])
static_out = exe.run(feed={"input": x,
"label": y},
fetch_list=result[0])
with self.dynamic_graph():
data = base.to_variable(x)
label = base.to_variable(y)
fc_out = fluid.layers.fc(data, size=10)
predict = fluid.layers.softmax(fc_out)
dynamic_out = fluid.layers.accuracy(input=predict, label=label, k=5)
self.assertTrue(np.array_equal(static_out[0], dynamic_out.numpy()))
class TestBook(LayerTest):
def test_all_layers(self):
......