提交 1d63b06b 编写于 作者: P phlrain

add grad test unit; test=develop

上级 24fa1f4b
......@@ -437,8 +437,17 @@ class OpTest(unittest.TestCase):
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_outputs = self.outputs if hasattr(self, "outputs") else dict()
op_attrs = self.attrs if hasattr(self, "attrs") else dict()
self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs,
op_attrs)
cache_list = None
if hasattr(self, "cache_name_list"):
cache_list = self.cache_name_list
self.op = create_op(
self.scope,
self.op_type,
op_inputs,
op_outputs,
op_attrs,
cache_list=cache_list)
if no_grad_set is None:
no_grad_set = set()
......
......@@ -121,9 +121,9 @@ class TestCUDNNLstmOp(OpTest):
self.op_type = "cudnn_lstm"
self.dtype = np.float32
num_steps = 50
batch_size = 20
hidden_size = 200
num_steps = 20
batch_size = 5
hidden_size = 20
input_weight_size = (hidden_size * hidden_size) * 4
hidden_weight_size = (hidden_size * hidden_size) * 4
......@@ -175,6 +175,15 @@ class TestCUDNNLstmOp(OpTest):
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
def test_grad_with_place(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
set(['Input', 'W', 'InitH', 'InitC']),
['Out', 'last_h', 'last_c'],
max_relative_error=0.02)
def testcuda(self):
return core.is_compiled_with_cuda()
......
......@@ -20,7 +20,7 @@ import paddle.fluid.core as core
from paddle.fluid.op import Operator
def create_op(scope, op_type, inputs, outputs, attrs):
def create_op(scope, op_type, inputs, outputs, attrs, cache_list=None):
kwargs = dict()
op_maker = core.op_proto_and_checker_maker
......@@ -43,6 +43,11 @@ def create_op(scope, op_type, inputs, outputs, attrs):
__create_var__(in_name, sub_in_name)
else:
__create_var__(in_name, in_name)
if cache_list != None and isinstance(cache_list, list):
for name in cache_list:
kwargs[name] = []
scope.var(name)
kwargs[name].append(name)
for out_name, out_dup in Operator.get_op_outputs(op_type):
if out_name in outputs:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册