From 072347ff93c61e6e80451b63858965a831a77b2f Mon Sep 17 00:00:00 2001
From: wawltor <980627148@qq.com>
Date: Thu, 13 Jun 2019 20:27:12 +0800
Subject: [PATCH] Fix the gather and scatter op bug with duplicate indices

cherry-pick from #17952 test=release/1.5

The scatter op had a calculation bug when the index tensor contains
duplicate indices: it always used overwrite mode, so for a repeated index
only the last update was kept. This patch fixes the bug by adding an
accumulate mode that sums the updates for a repeated index. The gather op
had the same bug when computing its gradient. The accumulate mode uses the
OpenBLAS and Eigen libraries to reduce its time cost.
---
 paddle/fluid/API.spec                         |  4 +-
 paddle/fluid/operators/gather_op.cc           |  7 ++
 paddle/fluid/operators/gather_op.cu           |  6 +-
 paddle/fluid/operators/gather_op.h            | 13 ++-
 paddle/fluid/operators/scatter.cu.h           | 42 ++++++++-
 paddle/fluid/operators/scatter.h              | 82 +++++++++++++++-
 paddle/fluid/operators/scatter_op.cc          |  8 ++
 paddle/fluid/operators/scatter_op.cu          |  4 +-
 paddle/fluid/operators/scatter_op.h           | 24 ++++-
 python/paddle/fluid/layers/nn.py              | 18 +++-
 .../fluid/tests/unittests/test_gather_op.py   | 27 ++++++
 .../fluid/tests/unittests/test_scatter_op.py  | 94 +++++++++++++++++++
 12 files changed, 311 insertions(+), 18 deletions(-)

diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 77780e4463..99f763f315 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -153,8 +153,8 @@ paddle.fluid.layers.image_resize (ArgSpec(args=['input', 'out_shape', 'scale', '
 paddle.fluid.layers.image_resize_short (ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)), ('document', '099b9f051e6247ae661e4a7b4fd3f89a'))
 paddle.fluid.layers.resize_bilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '746bf58fdb1bd475f8c5f996b05b0e52'))
 paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True)), ('document', '9baf9288c862161ff850d45228047a5e'))
-paddle.fluid.layers.gather (ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None), ('document', '01a198d6fff38d5f0d8180a40b228085'))
-paddle.fluid.layers.scatter (ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '846a53fd2991bdaab3a8134008eef0c7'))
+paddle.fluid.layers.gather (ArgSpec(args=['input', 'index', 'overwrite'], varargs=None, keywords=None, defaults=(True,)), ('document', '3569a6002a96c7f6b5e5bcfdc402df13'))
+paddle.fluid.layers.scatter (ArgSpec(args=['input', 'index', 'updates', 'name', 'overwrite'], varargs=None, keywords=None, defaults=(None, True)), ('document', '69b22affd4a6326502af166f04c095ab'))
 paddle.fluid.layers.sequence_scatter (ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '71df5136cf03b06c65027b692fe78f1a'))
 paddle.fluid.layers.random_crop (ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,)), ('document', 'c9ab9e460ef0a1823249935a30e82c66'))
 paddle.fluid.layers.mean_iou (ArgSpec(args=['input', 'label', 'num_classes'], varargs=None, keywords=None, defaults=None), ('document', 'e3b6630ba43cb13dfeeb1601cb64d671'))
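Note on semantics: with duplicate indices, the two update modes differ as in
this minimal NumPy sketch (illustration only, not part of the patch):

    import numpy as np

    x = np.ones((3, 2), dtype=np.float32)
    index = np.array([1, 1])                    # duplicate index
    updates = np.array([[1., 1.], [2., 2.]], dtype=np.float32)

    out = x.copy()
    out[index] = updates                        # overwrite: row 1 -> [2., 2.]

    out = x.copy()
    out[index] = 0.                             # accumulate: zero touched rows,
    np.add.at(out, index, updates)              # then sum: row 1 -> [3., 3.]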
diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc
index 91f3818f21..cbabd59cf6 100644
--- a/paddle/fluid/operators/gather_op.cc
+++ b/paddle/fluid/operators/gather_op.cc
@@ -74,6 +74,13 @@ class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X", "The source input of gather op");
     AddInput("Index", "The index input of gather op");
     AddOutput("Out", "The output of gather op");
+    AddAttr<bool>(
+        "overwrite",
+        "(bool, default: True) "
+        "In the backward pass, how to update the gradient when Index "
+        "contains duplicate indices: if true, use the overwrite mode for "
+        "duplicate indices; if false, use the accumulate mode.")
+        .SetDefault(true);
     AddComment(R"DOC(
 Gather Operator.

diff --git a/paddle/fluid/operators/gather_op.cu b/paddle/fluid/operators/gather_op.cu
index 7a0b290ec8..061f92c76c 100644
--- a/paddle/fluid/operators/gather_op.cu
+++ b/paddle/fluid/operators/gather_op.cu
@@ -76,9 +76,11 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
         paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
         paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
     if (index_type == framework::proto::VarType::INT32) {
-      GPUScatterAssign<T, int>(ctx.device_context(), *dO, *index, dX);
+      GPUScatterAssign<T, int>(ctx, *dO, *index, dX,
+                               ctx.Attr<bool>("overwrite"));
     } else if (index_type == framework::proto::VarType::INT64) {
-      GPUScatterAssign<T, int64_t>(ctx.device_context(), *dO, *index, dX);
+      GPUScatterAssign<T, int64_t>(ctx, *dO, *index, dX,
+                                   ctx.Attr<bool>("overwrite"));
     }
   }
 };

diff --git a/paddle/fluid/operators/gather_op.h b/paddle/fluid/operators/gather_op.h
index a58f794efa..852790a4c6 100644
--- a/paddle/fluid/operators/gather_op.h
+++ b/paddle/fluid/operators/gather_op.h
@@ -71,6 +71,7 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
                        .eigen_device();
     dxt.device(place) = dxt.constant(static_cast<T>(0));
     if (dO->numel() == 0) return;
+    bool overwrite = ctx.Attr<bool>("overwrite");
 
     const auto &index_type = index->type();
     bool index_type_match = index_type == framework::proto::VarType::INT32 ||
@@ -82,9 +83,17 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
         paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
         paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
     if (index_type == framework::proto::VarType::INT32) {
-      ScatterAssign<T, int32_t>(ctx.device_context(), *dO, *index, dX);
+      if (overwrite) {
+        ScatterAssign<T, int32_t>(ctx.device_context(), *dO, *index, dX);
+      } else {
+        ScatterAssignAdd<T, int32_t>(ctx, *dO, *index, dX);
+      }
     } else if (index_type == framework::proto::VarType::INT64) {
-      ScatterAssign<T, int64_t>(ctx.device_context(), *dO, *index, dX);
+      if (overwrite) {
+        ScatterAssign<T, int64_t>(ctx.device_context(), *dO, *index, dX);
+      } else {
+        ScatterAssignAdd<T, int64_t>(ctx, *dO, *index, dX);
+      }
     }
   }
 };
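The gather backward pass is where the bug showed up: it scatters d(Out) back
into d(X) at the gathered indices, and with a duplicate index the overwrite
mode kept only one contribution instead of their sum. A NumPy sketch of the
corrected accumulate behavior (illustration only, shapes assumed):

    import numpy as np

    index = np.array([1, 1, 3])                  # duplicate index, as in TestCase5
    d_out = np.random.rand(3, 4).astype(np.float32)

    d_x = np.zeros((10, 4), dtype=np.float32)
    np.add.at(d_x, index, d_out)                 # row 1 gets d_out[0] + d_out[1]
    # with overwrite (d_x[index] = d_out), row 1 would keep only d_out[1]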
diff --git a/paddle/fluid/operators/scatter.cu.h b/paddle/fluid/operators/scatter.cu.h
index 030719baa8..ce4af44266 100644
--- a/paddle/fluid/operators/scatter.cu.h
+++ b/paddle/fluid/operators/scatter.cu.h
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <unordered_set>
+#include "math/math_function.h"
 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
 #include "paddle/fluid/platform/place.h"
 
 namespace paddle {
@@ -24,17 +27,33 @@ using Tensor = framework::Tensor;
 #define CUDA_1D_KERNEL_LOOP(i, n)                              \
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
        i += blockDim.x * gridDim.x)
 
+template <typename T, typename IndexT = int>
+__global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output,
+                                      size_t index_size, size_t slice_size,
+                                      bool overwrite) {
+  CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
+    int indices_i = i / slice_size;
+    int slice_i = i - indices_i * slice_size;  // offset inside the slice
+    IndexT scatter_i = indices[indices_i];
+    IndexT out_i = scatter_i * slice_size + slice_i;
+    *(output + out_i) = static_cast<T>(0);
+  }
+}
+
 template <typename T, typename IndexT = int>
 __global__ void ScatterCUDAKernel(const T* params, const IndexT* indices,
                                   T* output, size_t index_size,
-                                  size_t slice_size) {
+                                  size_t slice_size, bool overwrite) {
   CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
     int indices_i = i / slice_size;
     int slice_i = i - indices_i * slice_size;  // offset inside the slice
     IndexT scatter_i = indices[indices_i];
     IndexT out_i = scatter_i * slice_size + slice_i;
-    *(output + out_i) = *(params + i);
+    if (overwrite) {
+      *(output + out_i) = *(params + i);
+    } else {
+      paddle::platform::CudaAtomicAdd(output + out_i, *(params + i));
+    }
   }
 }
 
@@ -47,10 +66,13 @@ __global__ void ScatterCUDAKernel(const T* params, const IndexT* indices,
  * return: output tensor
  */
 template <typename T, typename IndexT = int>
-void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
-                      const Tensor& index, Tensor* output) {
+void GPUScatterAssign(const framework::ExecutionContext& context,
+                      const Tensor& src, const Tensor& index, Tensor* output,
+                      bool overwrite = true) {
   // PADDLE_ENFORCE(platform::is_gpu_place(place));
   // check index of shape 1-D
+
+  const auto& ctx = context.device_context();
   PADDLE_ENFORCE(index.dims().size() == 1 ||
                  (index.dims().size() == 2 && index.dims()[1] == 1));
   int index_size = index.dims()[0];
@@ -66,15 +88,25 @@ void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
   const T* p_src = src.data<T>();
   const IndexT* p_index = index.data<IndexT>();
   T* p_output = output->data<T>();
+  const size_t& slice_bytes = slice_size * sizeof(T);
 
+  // set block and grid num
   int block = 512;
   int n = slice_size * index_size;
   int grid = (n + block - 1) / block;
 
+  // if not in overwrite mode, zero-init the output slices that will be updated
+  if (!overwrite) {
+    ScatterInitCUDAKernel<T, IndexT><<<
+        grid, block, 0,
+        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
+        p_index, p_output, index_size, slice_size, overwrite);
+  }
+
   ScatterCUDAKernel<T, IndexT><<<
       grid, block, 0,
       reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
       p_src, p_index, p_output, index_size, slice_size, overwrite);
 }
 
 }  // namespace operators
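The accumulate path on GPU takes two kernel launches on purpose: a single
kernel that both zeroed and added would race with itself, since a thread
handling a later duplicate could zero out a contribution another thread had
already added. A sequential Python sketch of the two passes (illustration
only):

    # pass 1 (ScatterInitCUDAKernel): zero every destination slice in index
    for i in index:
        out[i] = 0
    # pass 2 (ScatterCUDAKernel with overwrite=False): accumulate the updates;
    # the real kernel uses CudaAtomicAdd because duplicate indices may collide
    for i, upd in zip(index, updates):
        out[i] += upd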
diff --git a/paddle/fluid/operators/scatter.h b/paddle/fluid/operators/scatter.h
index 17d7d82144..680dc282c1 100644
--- a/paddle/fluid/operators/scatter.h
+++ b/paddle/fluid/operators/scatter.h
@@ -14,11 +14,14 @@ limitations under the License. */
 
 #pragma once
 #include <cstring>
+#include <string>
 
 #include "paddle/fluid/framework/ddim.h"
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/platform/place.h"
+#include <unordered_set>
 
 namespace paddle {
 namespace operators {
@@ -26,7 +29,42 @@ namespace operators {
 using Tensor = framework::Tensor;
 
 /**
- * Return a updated tensor from source tensor, scattered according to index:
+ * Return the updated array pointer; use the BLAS or Eigen library to reduce
+ * the time cost of the accumulation.
+ */
+template <typename T, typename IndexT>
+typename std::enable_if<std::is_floating_point<T>::value>::type
+elementwise_inner_add(const framework::ExecutionContext& ctx,
+                      const T* src_pointer, const T* dist_pointer,
+                      T* result_dist_pointer, const framework::Tensor& src,
+                      framework::Tensor* dist, const int& src_index,
+                      const IndexT& dist_index, const int& slice_size,
+                      const size_t& slice_bytes) {
+  auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
+
+  blas.VADD(slice_size, src_pointer + src_index * slice_size,
+            dist_pointer + dist_index * slice_size,
+            result_dist_pointer + dist_index * slice_size);
+}
+
+template <typename T, typename IndexT>
+typename std::enable_if<!std::is_floating_point<T>::value>::type
+elementwise_inner_add(const framework::ExecutionContext& ctx,
+                      const T* src_pointer, const T* dist_pointer,
+                      T* result_dist_pointer, const framework::Tensor& src,
+                      framework::Tensor* dist, const int& src_index,
+                      const IndexT& dist_index, const int& slice_size,
+                      const size_t& slice_bytes) {
+  auto src_slice = src.Slice(src_index, src_index + 1);
+  auto dist_slice = dist->Slice(dist_index, dist_index + 1);
+
+  auto eigen_src = framework::EigenVector<T>::Flatten(src_slice);
+  auto eigen_dist = framework::EigenVector<T>::Flatten(dist_slice);
+
+  eigen_dist += eigen_src;
+}
+
+/**
+ * Return an updated tensor from source tensor, scattered according to index:
  * dst[i] = src[index[i]]
  * input[src]: type-T source Tensor
  * input[index]: type-IndexT index Tensor (1-D)
@@ -64,5 +102,47 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
   }
 }
 
+template <typename T, typename IndexT = int>
+void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
+                      const Tensor& index, Tensor* output) {
+  PADDLE_ENFORCE(platform::is_cpu_place(ctx.device_context().GetPlace()));
+  // check index of shape 1-D
+  PADDLE_ENFORCE(index.dims().size() == 1 ||
+                 (index.dims().size() == 2 && index.dims()[1] == 1));
+  int index_size = index.dims()[0];
+
+  auto src_dims = src.dims();
+  auto dst_dims = output->dims();
+
+  const T* p_src = src.data<T>();
+  const IndexT* p_index = index.data<IndexT>();
+
+  const T* p_output = output->data<T>();
+  T* result_p_output = output->data<T>();
+
+  // src shape and dst shape should match (except on dim 0)
+  for (int i = 1; i < src_dims.size(); i++)
+    PADDLE_ENFORCE(src_dims[i] == dst_dims[i]);
+
+  // slice size
+  size_t slice_size = 1;
+  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
+
+  const size_t& slice_bytes = slice_size * sizeof(T);
+
+  // in accumulate mode, zero-init every output slice that will be updated
+  for (int i = 0; i < index_size; ++i) {
+    const IndexT& index_ = p_index[i];
+    memset(result_p_output + slice_size * index_, 0, slice_bytes);
+  }
+
+  for (int i = 0; i < index_size; ++i) {
+    const IndexT& index_ = p_index[i];
+    elementwise_inner_add<T, IndexT>(ctx, p_src, p_output, result_p_output,
+                                     src, output, i, index_, slice_size,
+                                     slice_bytes);
+  }
+}
+
 }  // namespace operators
 }  // namespace paddle
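elementwise_inner_add is dispatched with std::enable_if: floating-point types
go through blas.VADD (OpenBLAS), everything else through an Eigen slice add.
A rough Python analogue of that dispatch (blas_vadd is a hypothetical wrapper,
for illustration only):

    def elementwise_inner_add(dst_row, src_row):
        if dst_row.dtype.kind == 'f':     # float/double: vectorized BLAS path
            blas_vadd(dst_row, src_row)   # hypothetical BLAS VADD wrapper
        else:                             # other types: generic (Eigen) path
            dst_row += src_row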
diff --git a/paddle/fluid/operators/scatter_op.cc b/paddle/fluid/operators/scatter_op.cc
index 68ad223b3c..f5a1b32e5c 100644
--- a/paddle/fluid/operators/scatter_op.cc
+++ b/paddle/fluid/operators/scatter_op.cc
@@ -80,6 +80,14 @@ class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Ids", "The index input of scatter op where X will be updated");
     AddInput("Updates", "The updated value of scatter op");
     AddOutput("Out", "The output of scatter op");
+    AddAttr<bool>("overwrite",
+                  "(bool, default: True) "
+                  "How to update the output when Ids contains duplicate "
+                  "indices. If True, use the overwrite mode to update the "
+                  "output of the same index; if False, use the accumulate "
+                  "mode to update the output of the same index. You can set "
+                  "overwrite=False to implement scatter_add.")
+        .SetDefault(true);
     AddComment(R"DOC(
 Scatter Operator.

diff --git a/paddle/fluid/operators/scatter_op.cu b/paddle/fluid/operators/scatter_op.cu
index a70b909172..e9ad347538 100644
--- a/paddle/fluid/operators/scatter_op.cu
+++ b/paddle/fluid/operators/scatter_op.cu
@@ -30,10 +30,10 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> {
     auto *Ids = ctx.Input<Tensor>("Ids");
     auto *Updates = ctx.Input<Tensor>("Updates");
     auto *Out = ctx.Output<Tensor>("Out");
+    bool overwrite = ctx.Attr<bool>("overwrite");
 
     Out->ShareDataWith(*X);
-
-    GPUScatterAssign<T>(ctx.device_context(), *Updates, *Ids, Out);
+    GPUScatterAssign<T>(ctx, *Updates, *Ids, Out, overwrite);
   }
 };

diff --git a/paddle/fluid/operators/scatter_op.h b/paddle/fluid/operators/scatter_op.h
index 2eefbba972..9c237dc0f1 100644
--- a/paddle/fluid/operators/scatter_op.h
+++ b/paddle/fluid/operators/scatter_op.h
@@ -33,11 +33,33 @@ class ScatterOpKernel : public framework::OpKernel<T> {
     auto *Ids = ctx.Input<Tensor>("Ids");
     auto *Updates = ctx.Input<Tensor>("Updates");
     auto *Out = ctx.Output<Tensor>("Out");
+    bool overwrite = ctx.Attr<bool>("overwrite");
 
     // In place output: Out = X, Out[Ids] = Updates
     framework::TensorCopySync(*X, ctx.GetPlace(), Out);
     // Apply ScatterUpdate: Out[index] = Updates[:]
-    ScatterAssign<T>(ctx.device_context(), *Updates, *Ids, Out);
+    const auto &index_type = Ids->type();
+    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                            index_type == framework::proto::VarType::INT64;
+    PADDLE_ENFORCE(
+        index_type_match,
+        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
+        paddle::framework::DataTypeToString(index_type),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    if (overwrite) {
+      if (index_type == framework::proto::VarType::INT32) {
+        ScatterAssign<T, int32_t>(ctx.device_context(), *Updates, *Ids, Out);
+      } else {
+        ScatterAssign<T, int64_t>(ctx.device_context(), *Updates, *Ids, Out);
+      }
+    } else {
+      if (index_type == framework::proto::VarType::INT32) {
+        ScatterAssignAdd<T, int32_t>(ctx, *Updates, *Ids, Out);
+      } else {
+        ScatterAssignAdd<T, int64_t>(ctx, *Updates, *Ids, Out);
+      }
+    }
   }
 };
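With the new attribute exposed in Python below, an accumulate-mode scatter
behaves like a scatter_add. A usage sketch against the release/1.5 API
(shapes assumed, illustration only):

    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[3, 3], append_batch_size=False)
    ids = fluid.layers.data(name='ids', shape=[2], dtype='int32',
                            append_batch_size=False)
    upd = fluid.layers.data(name='upd', shape=[2, 3],
                            append_batch_size=False)
    out = fluid.layers.scatter(x, ids, upd, overwrite=False)  # scatter_add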
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 44b6794f34..a5dd8a6f08 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -7855,7 +7855,7 @@ def image_resize_short(input, out_short_len, resample='BILINEAR'):
     return image_resize(input=input, out_shape=out_shape, resample=resample)
 
 
-def gather(input, index):
+def gather(input, index, overwrite=True):
     """
     **Gather Layer**
 
@@ -7886,6 +7886,12 @@ def gather(input, index):
     Args:
         input (Variable): The source input with rank>=1.
         index (Variable): The index input with rank=1.
+        overwrite (bool): The mode that updates the grad when there are
+            duplicate indices: if True, use the overwrite mode to update the
+            grad of the same index; if False, use the accumulate mode.
+            Default value is True.
+
+
 
     Returns:
         output (Variable): The output is a tensor with the same rank as input.
@@ -7905,11 +7911,12 @@ def gather(input, index):
         type="gather",
         inputs={"X": input,
                 "Index": index},
-        outputs={"Out": out})
+        outputs={"Out": out},
+        attrs={'overwrite': overwrite})
     return out
 
 
-def scatter(input, index, updates, name=None):
+def scatter(input, index, updates, name=None, overwrite=True):
     """
     **Scatter Layer**
 
@@ -7927,6 +7934,10 @@ def scatter(input, index, updates, name=None):
             int32 or int64 as it is used as indexes.
         updates (Variable): The updated value of scatter op.
         name (str|None): The output variable name. Default None.
+        overwrite (bool): The mode that updates the output when there are
+            duplicate indices: if True, use the overwrite mode to update the
+            output of the same index; if False, use the accumulate mode.
+            Default value is True. Set overwrite=False to implement scatter_add.
 
     Returns:
         output (Variable): The output is a tensor with the same shape as input.
@@ -7951,6 +7962,7 @@ def scatter(input, index, updates, name=None):
         inputs={"X": input,
                 "Ids": index,
                 "Updates": updates},
+        attrs={'overwrite': overwrite},
         outputs={"Out": out})
     return out

diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py
index daa5e60498..119f64ce73 100644
--- a/python/paddle/fluid/tests/unittests/test_gather_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gather_op.py
@@ -79,5 +79,32 @@ class TestCase3(TestGatherOp):
         self.index_type = "int64"
 
 
+class TestCase4(TestGatherOp):
+    def config(self):
+        self.x_shape = (10, 20)
+        self.attrs = {'overwrite': False}
+        self.x_type = "double"
+        self.index = [1, 1]
+        self.index_type = "int32"
+
+
+class TestCase5(TestGatherOp):
+    def config(self):
+        self.x_shape = (10, 20)
+        self.attrs = {'overwrite': False}
+        self.x_type = "float"
+        self.index = [1, 1, 3]
+        self.index_type = "int32"
+
+
+class TestCase6(TestGatherOp):
+    def config(self):
+        self.x_shape = (10, 20)
+        self.attrs = {'overwrite': True}
+        self.x_type = "float"
+        self.index = [1, 3]
+        self.index_type = "int32"
+
+
 if __name__ == "__main__":
     unittest.main()
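The accumulate-mode scatter tests below build their expected output the same
way: zero the duplicated rows, then sum. Concretely, with ref = ones((3, 3)),
index [1, 1] and update rows u0 and u1, the expected row 1 is u0 + u1 (the
original ones are zeroed first) while rows 0 and 2 stay all-ones. A quick
check (illustration only):

    import numpy as np

    ref = np.ones((3, 3), dtype=np.float32)
    idx = np.array([1, 1], dtype=np.int32)
    upd = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)

    expected = ref.copy()
    expected[idx] = 0.0
    np.add.at(expected, idx, upd)
    # expected[1] == [5., 7., 9.]; expected[0] == expected[2] == [1., 1., 1.]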
diff --git a/python/paddle/fluid/tests/unittests/test_scatter_op.py b/python/paddle/fluid/tests/unittests/test_scatter_op.py
index 088996f9d7..9c60a11828 100644
--- a/python/paddle/fluid/tests/unittests/test_scatter_op.py
+++ b/python/paddle/fluid/tests/unittests/test_scatter_op.py
@@ -17,6 +17,7 @@ from __future__ import print_function
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle.fluid.core as core
 
 
 class TestScatterOp(OpTest):
@@ -37,5 +38,98 @@ class TestScatterOp(OpTest):
         self.check_grad(['Updates'], 'Out', in_place=True)
 
 
+class TestScatterOp0(OpTest):
+    def setUp(self):
+        self.op_type = "scatter"
+        ref_np = np.ones((3, 3)).astype("float32")
+        index_np = np.array([1, 2]).astype("int32")
+        updates_np = np.random.random((2, 3)).astype("float32")
+        output_np = np.copy(ref_np)
+        output_np[index_np] = updates_np
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.attrs = {'overwrite': True}
+        self.outputs = {'Out': output_np}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['Updates'], 'Out', in_place=True)
+
+
+class TestScatterOp1(OpTest):
+    def setUp(self):
+        self.op_type = "scatter"
+        ref_np = np.ones((3, 3)).astype("float32")
+        zeros_np = np.zeros([2, 3]).astype('float32')
+        index_np = np.array([1, 1]).astype("int32")
+        updates_np = np.random.random((2, 3)).astype("float32")
+        output_np = np.copy(ref_np)
+        output_np[index_np] = zeros_np
+        for i in range(0, len(index_np)):
+            output_np[index_np[i]] += updates_np[i]
+        self.attrs = {'overwrite': False}
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': output_np}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['Updates'], 'Out', in_place=True)
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestScatterOp2(OpTest):
+    def setUp(self):
+        self.op_type = "scatter"
+        ref_np = np.ones((3, 3)).astype("float32")
+        index_np = np.array([1, 2]).astype("int32")
+        updates_np = np.random.random((2, 3)).astype("float32")
+        output_np = np.copy(ref_np)
+        output_np[index_np] = updates_np
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': output_np}
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, atol=1e-3)
+
+    def test_check_grad(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(place, ['Updates'], 'Out',
+                                       in_place=True)
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestScatterOp3(OpTest):
+    def setUp(self):
+        self.op_type = "scatter"
+        ref_np = np.ones((3, 3)).astype("float32")
+        zeros_np = np.zeros([2, 3]).astype('float32')
+        index_np = np.array([1, 1]).astype("int32")
+        updates_np = np.random.random((2, 3)).astype("float32")
+        output_np = np.copy(ref_np)
+        output_np[index_np] = zeros_np
+        for i in range(0, len(index_np)):
+            output_np[index_np[i]] += updates_np[i]
+        self.attrs = {'overwrite': False}
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': output_np}
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, atol=1e-3)
+
+    def test_check_grad(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(place, ['Updates'], 'Out',
+                                       in_place=True)
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab