From 8eb134c3c1be1b488716223587bd789863c925b2 Mon Sep 17 00:00:00 2001 From: wawltor <980627148@qq.com> Date: Wed, 12 Jun 2019 18:04:18 +0800 Subject: [PATCH] Fix scatter and gather op when has duplicate index (#17952) * test=develop The scatter op has a calc bug when the indices has same index, the scatter op use overwrite mode to calculate the same index, fix this bug by using the accumulate mode to calculate the same index.At the same time, the gather op has the same bug when the op calc the grad. And we use the lib of open-blas and eigen to optimize the time cost in accumulate mode. * test=develop Fix some code format problem, and the same time add the test case in gather and scatter op --- paddle/fluid/API.spec | 4 +- paddle/fluid/operators/gather_op.cc | 7 ++ paddle/fluid/operators/gather_op.cu | 6 +- paddle/fluid/operators/gather_op.h | 13 ++- paddle/fluid/operators/scatter.cu.h | 42 ++++++++- paddle/fluid/operators/scatter.h | 82 +++++++++++++++- paddle/fluid/operators/scatter_op.cc | 8 ++ paddle/fluid/operators/scatter_op.cu | 4 +- paddle/fluid/operators/scatter_op.h | 24 ++++- python/paddle/fluid/layers/nn.py | 18 +++- .../fluid/tests/unittests/test_gather_op.py | 27 ++++++ .../fluid/tests/unittests/test_scatter_op.py | 94 +++++++++++++++++++ 12 files changed, 311 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 3ee4f045e36..6a323e5d794 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -153,8 +153,8 @@ paddle.fluid.layers.image_resize (ArgSpec(args=['input', 'out_shape', 'scale', ' paddle.fluid.layers.image_resize_short (ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)), ('document', '099b9f051e6247ae661e4a7b4fd3f89a')) paddle.fluid.layers.resize_bilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '746bf58fdb1bd475f8c5f996b05b0e52')) paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True)), ('document', '9baf9288c862161ff850d45228047a5e')) -paddle.fluid.layers.gather (ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None), ('document', '01a198d6fff38d5f0d8180a40b228085')) -paddle.fluid.layers.scatter (ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '846a53fd2991bdaab3a8134008eef0c7')) +paddle.fluid.layers.gather (ArgSpec(args=['input', 'index', 'overwrite'], varargs=None, keywords=None, defaults=(True,)), ('document', '3569a6002a96c7f6b5e5bcfdc402df13')) +paddle.fluid.layers.scatter (ArgSpec(args=['input', 'index', 'updates', 'name', 'overwrite'], varargs=None, keywords=None, defaults=(None, True)), ('document', '69b22affd4a6326502af166f04c095ab')) paddle.fluid.layers.sequence_scatter (ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '71df5136cf03b06c65027b692fe78f1a')) paddle.fluid.layers.random_crop (ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,)), ('document', 'c9ab9e460ef0a1823249935a30e82c66')) paddle.fluid.layers.mean_iou (ArgSpec(args=['input', 'label', 'num_classes'], varargs=None, keywords=None, defaults=None), ('document', 'e3b6630ba43cb13dfeeb1601cb64d671')) diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index 91f3818f216..cbabd59cf63 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -74,6 +74,13 @@ class GatherOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "The source input of gather op"); AddInput("Index", "The index input of gather op"); AddOutput("Out", "The output of gather op"); + AddAttr( + "overwrite", + "(bool, default: False) " + "In backward process, calc the grad when has same index," + "If true, update the grad using the overwrite mode in same index," + "If false, using the accumulate mode in same index.") + .SetDefault(true); AddComment(R"DOC( Gather Operator. diff --git a/paddle/fluid/operators/gather_op.cu b/paddle/fluid/operators/gather_op.cu index 7a0b290ec86..061f92c76c3 100644 --- a/paddle/fluid/operators/gather_op.cu +++ b/paddle/fluid/operators/gather_op.cu @@ -76,9 +76,11 @@ class GatherGradOpCUDAKernel : public framework::OpKernel { paddle::framework::DataTypeToString(framework::proto::VarType::INT32), paddle::framework::DataTypeToString(framework::proto::VarType::INT64)); if (index_type == framework::proto::VarType::INT32) { - GPUScatterAssign(ctx.device_context(), *dO, *index, dX); + GPUScatterAssign(ctx, *dO, *index, dX, + ctx.Attr("overwrite")); } else if (index_type == framework::proto::VarType::INT64) { - GPUScatterAssign(ctx.device_context(), *dO, *index, dX); + GPUScatterAssign(ctx, *dO, *index, dX, + ctx.Attr("overwrite")); } } }; diff --git a/paddle/fluid/operators/gather_op.h b/paddle/fluid/operators/gather_op.h index a58f794efa9..852790a4c63 100644 --- a/paddle/fluid/operators/gather_op.h +++ b/paddle/fluid/operators/gather_op.h @@ -71,6 +71,7 @@ class GatherGradientOpKernel : public framework::OpKernel { .eigen_device(); dxt.device(place) = dxt.constant(static_cast(0)); if (dO->numel() == 0) return; + bool overwrite = ctx.Attr("overwrite"); const auto &index_type = index->type(); bool index_type_match = index_type == framework::proto::VarType::INT32 || @@ -82,9 +83,17 @@ class GatherGradientOpKernel : public framework::OpKernel { paddle::framework::DataTypeToString(framework::proto::VarType::INT32), paddle::framework::DataTypeToString(framework::proto::VarType::INT64)); if (index_type == framework::proto::VarType::INT32) { - ScatterAssign(ctx.device_context(), *dO, *index, dX); + if (overwrite) { + ScatterAssign(ctx.device_context(), *dO, *index, dX); + } else { + ScatterAssignAdd(ctx, *dO, *index, dX); + } } else if (index_type == framework::proto::VarType::INT64) { - ScatterAssign(ctx.device_context(), *dO, *index, dX); + if (overwrite) { + ScatterAssign(ctx.device_context(), *dO, *index, dX); + } else { + ScatterAssignAdd(ctx, *dO, *index, dX); + } } } }; diff --git a/paddle/fluid/operators/scatter.cu.h b/paddle/fluid/operators/scatter.cu.h index 030719baa8f..ce4af44266e 100644 --- a/paddle/fluid/operators/scatter.cu.h +++ b/paddle/fluid/operators/scatter.cu.h @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include +#include "math/math_function.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/place.h" namespace paddle { @@ -24,17 +27,33 @@ using Tensor = framework::Tensor; #define CUDA_1D_KERNEL_LOOP(i, n) \ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ i += blockDim.x * gridDim.x) +template +__global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output, + size_t index_size, size_t slice_size, + bool overwrite) { + CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) { + int indices_i = i / slice_size; + int slice_i = i - indices_i * slice_size; // offset inside the slice + IndexT scatter_i = indices[indices_i]; + IndexT out_i = scatter_i * slice_size + slice_i; + *(output + out_i) = static_cast(0); + } +} template __global__ void ScatterCUDAKernel(const T* params, const IndexT* indices, T* output, size_t index_size, - size_t slice_size) { + size_t slice_size, bool overwrite) { CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) { int indices_i = i / slice_size; int slice_i = i - indices_i * slice_size; // offset inside the slice IndexT scatter_i = indices[indices_i]; IndexT out_i = scatter_i * slice_size + slice_i; - *(output + out_i) = *(params + i); + if (overwrite) { + *(output + out_i) = *(params + i); + } else { + paddle::platform::CudaAtomicAdd(output + out_i, *(params + i)); + } } } @@ -47,10 +66,13 @@ __global__ void ScatterCUDAKernel(const T* params, const IndexT* indices, * return: output tensor */ template -void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src, - const Tensor& index, Tensor* output) { +void GPUScatterAssign(const framework::ExecutionContext& context, + const Tensor& src, const Tensor& index, Tensor* output, + bool overwrite = true) { // PADDLE_ENFORCE(platform::is_gpu_place(place)); // check index of shape 1-D + + const auto& ctx = context.device_context(); PADDLE_ENFORCE(index.dims().size() == 1 || (index.dims().size() == 2 && index.dims()[1] == 1)); int index_size = index.dims()[0]; @@ -66,15 +88,25 @@ void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src, const T* p_src = src.data(); const IndexT* p_index = index.data(); T* p_output = output->data(); + const size_t& slice_bytes = slice_size * sizeof(T); + // set block and grid num int block = 512; int n = slice_size * index_size; int grid = (n + block - 1) / block; + // if not overwrite mode, init data + if (!overwrite) { + ScatterInitCUDAKernel<<< + grid, block, 0, + reinterpret_cast(ctx).stream()>>>( + p_index, p_output, index_size, slice_size, overwrite); + } + ScatterCUDAKernel<<< grid, block, 0, reinterpret_cast(ctx).stream()>>>( - p_src, p_index, p_output, index_size, slice_size); + p_src, p_index, p_output, index_size, slice_size, overwrite); } } // namespace operators diff --git a/paddle/fluid/operators/scatter.h b/paddle/fluid/operators/scatter.h index 17d7d82144d..680dc282c14 100644 --- a/paddle/fluid/operators/scatter.h +++ b/paddle/fluid/operators/scatter.h @@ -14,11 +14,14 @@ limitations under the License. */ #pragma once #include +#include #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/place.h" +#include "unordered_set" namespace paddle { namespace operators { @@ -26,7 +29,42 @@ namespace operators { using Tensor = framework::Tensor; /** - * Return a updated tensor from source tensor, scattered according to index: + * Return the updated array pointer, use blas or eigen lib to optimize time + * cost + */ +template +typename std::enable_if::value>::type +elementwise_inner_add(const framework::ExecutionContext& ctx, + const T* src_pointer, const T* dist_pointer, + T* result_dist_pointer, const framework::Tensor& src, + framework::Tensor* dist, const int& src_index, + const IndexT& dist_index, const int& slice_size, + const size_t& slice_bytes) { + auto blas = math::GetBlas(ctx); + + blas.VADD(slice_size, src_pointer + src_index * slice_size, + dist_pointer + dist_index * slice_size, + result_dist_pointer + dist_index * slice_size); +} + +template +typename std::enable_if::value>::type +elementwise_inner_add(const framework::ExecutionContext& ctx, + const T* src_pointer, const T* dist_pointer, + T* result_dist_pointer, const framework::Tensor& src, + framework::Tensor* dist, const int& src_index, + const IndexT& dist_index, const int& slice_size, + const size_t& slice_bytes) { + auto src_slice = src.Slice(src_index, src_index + 1); + auto dist_slice = dist->Slice(dist_index, dist_index + 1); + + auto eigen_src = framework::EigenVector::Flatten(src_slice); + auto eigen_dist = framework::EigenVector::Flatten(dist_slice); + + eigen_dist += eigen_src; +} +/** + * Return an updated tensor from source tensor, scattered according to index: * dst[i] = src[index[i]] * input[src]: type-T source Tensor * input[index]: type-IndexT index Tensor (1-D) @@ -64,5 +102,47 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src, } } +template +void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src, + const Tensor& index, Tensor* output) { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.device_context().GetPlace())); + // check index of shape 1-D + PADDLE_ENFORCE(index.dims().size() == 1 || + (index.dims().size() == 2 && index.dims()[1] == 1)); + int index_size = index.dims()[0]; + + auto src_dims = src.dims(); + auto dst_dims = output->dims(); + + const T* p_src = src.data(); + const IndexT* p_index = index.data(); + + const T* p_output = output->data(); + T* result_p_output = output->data(); + + // check src shape and dst shape should match + for (int i = 1; i < src_dims.size(); i++) + PADDLE_ENFORCE(src_dims[i] == dst_dims[i]); + + // slice size + size_t slice_size = 1; + for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + + const size_t& slice_bytes = slice_size * sizeof(T); + + // if not in overwrite mode, need to init output data + for (int i = 0; i < index_size; ++i) { + const IndexT& index_ = p_index[i]; + memset(result_p_output + slice_size * index_, 0, slice_bytes); + } + + for (int i = 0; i < index_size; ++i) { + const IndexT& index_ = p_index[i]; + elementwise_inner_add(ctx, p_src, p_output, result_p_output, src, + output, i, index_, slice_size, + slice_bytes); + } +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/scatter_op.cc b/paddle/fluid/operators/scatter_op.cc index 68ad223b3c3..f5a1b32e5c2 100644 --- a/paddle/fluid/operators/scatter_op.cc +++ b/paddle/fluid/operators/scatter_op.cc @@ -80,6 +80,14 @@ class ScatterOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Ids", "The index input of scatter op where X will be updated"); AddInput("Updates", "The updated value of scatter op"); AddOutput("Out", "The output of scatter op"); + AddAttr("overwrite", + "(bool, defalut: True) " + "The mode that updating the output when has same index," + "If True, use the overwrite mode to update the output" + "of the same index, if False, use the accumulate mode to" + "update the output of the same index,Default value is True." + "You can set overwrite=False to implement scatter_add.") + .SetDefault(true); AddComment(R"DOC( Scatter Operator. diff --git a/paddle/fluid/operators/scatter_op.cu b/paddle/fluid/operators/scatter_op.cu index a70b9091727..e9ad3475381 100644 --- a/paddle/fluid/operators/scatter_op.cu +++ b/paddle/fluid/operators/scatter_op.cu @@ -30,10 +30,10 @@ class ScatterOpCUDAKernel : public framework::OpKernel { auto *Ids = ctx.Input("Ids"); auto *Updates = ctx.Input("Updates"); auto *Out = ctx.Output("Out"); + bool overwrite = ctx.Attr("overwrite"); Out->ShareDataWith(*X); - - GPUScatterAssign(ctx.device_context(), *Updates, *Ids, Out); + GPUScatterAssign(ctx, *Updates, *Ids, Out, overwrite); } }; diff --git a/paddle/fluid/operators/scatter_op.h b/paddle/fluid/operators/scatter_op.h index 2eefbba9726..9c237dc0f1f 100644 --- a/paddle/fluid/operators/scatter_op.h +++ b/paddle/fluid/operators/scatter_op.h @@ -33,11 +33,33 @@ class ScatterOpKernel : public framework::OpKernel { auto *Ids = ctx.Input("Ids"); auto *Updates = ctx.Input("Updates"); auto *Out = ctx.Output("Out"); + double overwrite = ctx.Attr("overwrite"); // In place output: Out = X, Out[Ids] = Updates framework::TensorCopySync(*X, ctx.GetPlace(), Out); // Apply ScatterUpdate: Out[index] = Updates[:] - ScatterAssign(ctx.device_context(), *Updates, *Ids, Out); + const auto &index_type = Ids->type(); + bool index_type_match = index_type == framework::proto::VarType::INT32 || + index_type == framework::proto::VarType::INT64; + PADDLE_ENFORCE( + index_type_match, + "Index holds the wrong type, it holds %s, but desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString(framework::proto::VarType::INT32), + paddle::framework::DataTypeToString(framework::proto::VarType::INT64)); + if (overwrite) { + if (index_type == framework::proto::VarType::INT32) { + ScatterAssign(ctx.device_context(), *Updates, *Ids, Out); + } else { + ScatterAssign(ctx.device_context(), *Updates, *Ids, Out); + } + } else { + if (index_type == framework::proto::VarType::INT32) { + ScatterAssignAdd(ctx, *Updates, *Ids, Out); + } else { + ScatterAssignAdd(ctx, *Updates, *Ids, Out); + } + } } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 9c947cce543..07a65969f91 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -7876,7 +7876,7 @@ def image_resize_short(input, out_short_len, resample='BILINEAR'): return image_resize(input=input, out_shape=out_shape, resample=resample) -def gather(input, index): +def gather(input, index, overwrite=True): """ **Gather Layer** @@ -7907,6 +7907,12 @@ def gather(input, index): Args: input (Variable): The source input with rank>=1. index (Variable): The index input with rank=1. + overwrite (bool): The mode that updating the grad when has same index. + If True, use the overwrite mode to update the grad of the same index, + if False, use the accumulate mode to update the grad of the same index. + Default value is True. + + Returns: output (Variable): The output is a tensor with the same rank as input. @@ -7926,11 +7932,12 @@ def gather(input, index): type="gather", inputs={"X": input, "Index": index}, - outputs={"Out": out}) + outputs={"Out": out}, + attrs={'overwrite': overwrite}) return out -def scatter(input, index, updates, name=None): +def scatter(input, index, updates, name=None, overwrite=True): """ **Scatter Layer** @@ -7948,6 +7955,10 @@ def scatter(input, index, updates, name=None): int32 or int64 as it is used as indexes. updates (Variable): The updated value of scatter op. name (str|None): The output variable name. Default None. + overwrite (bool): The mode that updating the output when has same index. + If True, use the overwrite mode to update the output of the same index, + if False, use the accumulate mode to update the output of the same index. + Default value is True.You can set overwrite=False to implement scatter_add. Returns: output (Variable): The output is a tensor with the same shape as input. @@ -7972,6 +7983,7 @@ def scatter(input, index, updates, name=None): inputs={"X": input, "Ids": index, "Updates": updates}, + attrs={'overwrite': overwrite}, outputs={"Out": out}) return out diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index daa5e60498e..119f64ce734 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -79,5 +79,32 @@ class TestCase3(TestGatherOp): self.index_type = "int64" +class TestCase4(TestGatherOp): + def config(self): + self.x_shape = (10, 20) + self.attrs = {'overwrite': False} + self.x_type = "double" + self.index = [1, 1] + self.index_type = "int32" + + +class TestCase5(TestGatherOp): + def config(self): + self.x_shape = (10, 20) + self.attrs = {'overwrite': False} + self.x_type = "float" + self.index = [1, 1, 3] + self.index_type = "int32" + + +class TestCase6(TestGatherOp): + def config(self): + self.x_shape = (10, 20) + self.attrs = {'overwrite': True} + self.x_type = "float" + self.index = [1, 3] + self.index_type = "int32" + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_scatter_op.py b/python/paddle/fluid/tests/unittests/test_scatter_op.py index 088996f9d7d..9c60a118285 100644 --- a/python/paddle/fluid/tests/unittests/test_scatter_op.py +++ b/python/paddle/fluid/tests/unittests/test_scatter_op.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle.fluid.core as core class TestScatterOp(OpTest): @@ -37,5 +38,98 @@ class TestScatterOp(OpTest): self.check_grad(['Updates'], 'Out', in_place=True) +class TestScatterOp0(OpTest): + def setUp(self): + self.op_type = "scatter" + ref_np = np.ones((3, 3)).astype("float32") + index_np = np.array([1, 2]).astype("int32") + updates_np = np.random.random((2, 3)).astype("float32") + output_np = np.copy(ref_np) + output_np[index_np] = updates_np + self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np} + self.attrs = {'overwrite': True} + self.outputs = {'Out': output_np} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['Updates'], 'Out', in_place=True) + + +class TestScatterOp1(OpTest): + def setUp(self): + self.op_type = "scatter" + ref_np = np.ones((3, 3)).astype("float32") + zeros_np = np.zeros([2, 3]).astype('float32') + index_np = np.array([1, 1]).astype("int32") + updates_np = np.random.random((2, 3)).astype("float32") + output_np = np.copy(ref_np) + output_np[index_np] = zeros_np + for i in range(0, len(index_np)): + output_np[index_np[i]] += updates_np[i] + self.attrs = {'overwrite': False} + self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np} + self.outputs = {'Out': output_np} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['Updates'], 'Out', in_place=True) + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestScatterOp2(OpTest): + def setUp(self): + self.op_type = "scatter" + ref_np = np.ones((3, 3)).astype("float32") + index_np = np.array([1, 2]).astype("int32") + updates_np = np.random.random((2, 3)).astype("float32") + output_np = np.copy(ref_np) + output_np[index_np] = updates_np + self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np} + self.outputs = {'Out': output_np} + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=1e-3) + + def test_check_grad(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True) + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestScatterOp3(OpTest): + def setUp(self): + self.op_type = "scatter" + ref_np = np.ones((3, 3)).astype("float32") + zeros_np = np.zeros([2, 3]).astype('float32') + index_np = np.array([1, 1]).astype("int32") + updates_np = np.random.random((2, 3)).astype("float32") + output_np = np.copy(ref_np) + output_np[index_np] = zeros_np + for i in range(0, len(index_np)): + output_np[index_np[i]] += updates_np[i] + self.attrs = {'overwrite': False} + self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np} + self.outputs = {'Out': output_np} + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=1e-3) + + def test_check_grad(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True) + + if __name__ == "__main__": unittest.main() -- GitLab