未验证 提交 374abcf3 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #16247 from panyx0718/imperative

add more imperative layer tests.
......@@ -214,10 +214,8 @@ framework::LoDTensor& VarBase::GradValue() {
}
std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
if (grad_op_descs_.empty() && backward_id_ <= 0) {
VLOG(3) << "op with no grad: " << Type();
return {};
}
PADDLE_ENFORCE(!grad_op_descs_.empty() || backward_id_ > 0,
"%s has no backward implementation", Type());
VLOG(3) << "apply op grad: " << Type();
std::vector<framework::VariableValueMap> tmp_grad_outputs;
......
......@@ -32,11 +32,12 @@ void CreateGradOp(const framework::OpDesc& op_desc,
std::vector<framework::OpDesc*>* grad_op_descs,
std::unordered_map<std::string, std::string>* grad_to_var) {
PADDLE_ENFORCE(grad_op_descs->empty());
std::vector<std::unique_ptr<framework::OpDesc>> descs =
framework::OpInfoMap::Instance()
.Get(op_desc.Type())
.GradOpMaker()(op_desc, no_grad_set, grad_to_var, grad_sub_block);
const framework::OpInfo& op_info =
framework::OpInfoMap::Instance().Get(op_desc.Type());
if (!op_info.grad_op_maker_) return;
std::vector<std::unique_ptr<framework::OpDesc>> descs =
op_info.GradOpMaker()(op_desc, no_grad_set, grad_to_var, grad_sub_block);
for (auto& desc : descs) {
grad_op_descs->emplace_back(desc.release());
}
......
......@@ -24,6 +24,7 @@ import inspect
from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant, NumpyArrayInitializer
from ..framework import Variable, OpProtoHolder, _in_imperative_mode
from ..imperative import base
from ..param_attr import ParamAttr
from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_
from .tensor import concat, assign
......@@ -9138,6 +9139,10 @@ def _elementwise_op(helper):
op_type = helper.layer_type
x = helper.kwargs.get('x', None)
y = helper.kwargs.get('y', None)
if _in_imperative_mode():
x = base.to_variable(x)
y = base.to_variable(y)
assert x is not None, 'x cannot be None in {}'.format(op_type)
assert y is not None, 'y cannot be None in {}'.format(op_type)
axis = helper.kwargs.get('axis', -1)
......
......@@ -174,6 +174,60 @@ class TestLayer(LayerTest):
self.assertTrue(np.allclose(static_ret[i], static_ret2[i]))
self.assertTrue(np.allclose(static_ret[i], dy_ret[i]._numpy()))
def test_elementwise_math(self):
n = np.ones([3, 3], dtype='float32')
n2 = np.ones([3, 3], dtype='float32') * 1.1
n3 = np.ones([3, 3], dtype='float32') * 2
n4 = np.ones([3, 3], dtype='float32') * 3
n5 = np.ones([3, 3], dtype='float32') * 4
n6 = np.ones([3, 3], dtype='float32') * 5
with self.static_graph():
t = layers.data(name='t', shape=[3, 3], dtype='float32')
t2 = layers.data(name='t2', shape=[3, 3], dtype='float32')
t3 = layers.data(name='t3', shape=[3, 3], dtype='float32')
t4 = layers.data(name='t4', shape=[3, 3], dtype='float32')
t5 = layers.data(name='t5', shape=[3, 3], dtype='float32')
t6 = layers.data(name='t6', shape=[3, 3], dtype='float32')
ret = layers.elementwise_add(t, t2)
ret = layers.elementwise_pow(ret, t3)
ret = layers.elementwise_div(ret, t4)
ret = layers.elementwise_sub(ret, t5)
ret = layers.elementwise_mul(ret, t6)
static_ret = self.get_static_graph_result(
feed={
't': n,
't2': n2,
't3': n3,
't4': n4,
't5': n5,
't6': n6
},
fetch_list=[ret])[0]
with self.dynamic_graph():
ret = layers.elementwise_add(n, n2)
ret = layers.elementwise_pow(ret, n3)
ret = layers.elementwise_div(ret, n4)
ret = layers.elementwise_sub(ret, n5)
dy_ret = layers.elementwise_mul(ret, n6)
self.assertTrue(
np.allclose(static_ret, dy_ret._numpy()),
'%s vs %s' % (static_ret, dy_ret._numpy()))
def test_elementwise_minmax(self):
n = np.ones([3, 3], dtype='float32')
n2 = np.ones([3, 3], dtype='float32') * 2
with self.dynamic_graph():
min_ret = layers.elementwise_min(n, n2)
max_ret = layers.elementwise_max(n, n2)
self.assertTrue(np.allclose(n, min_ret._numpy()))
self.assertTrue(np.allclose(n2, max_ret._numpy()))
class TestBook(unittest.TestCase):
def test_fit_a_line(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册