未验证 提交 3c457a38 编写于 作者: Z Zeng Jinle 提交者: GitHub

Fix scatter_nd_add and gather bug (#35544)

* fix scatter_add_nd and gather bug

* fix gather compile error
上级 5f369881
......@@ -32,9 +32,9 @@ template <typename T, typename IndexT = int>
__global__ void GatherCUDAKernel(const T* params, const IndexT* indices,
T* output, size_t index_size,
size_t slice_size) {
CUDA_KERNEL_LOOP(i, index_size * slice_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size; // offset inside the slice
IndexT gather_i = indices[indices_i];
IndexT params_i = gather_i * slice_size + slice_i;
*(output + i) = *(params + params_i);
......@@ -42,13 +42,13 @@ __global__ void GatherCUDAKernel(const T* params, const IndexT* indices,
}
template <typename T, typename IndexT = int>
__global__ void GatherNdCUDAKernel(const T* input, const int* input_dims,
__global__ void GatherNdCUDAKernel(const T* input, const int64_t* input_dims,
const IndexT* indices, T* output,
size_t remain_size, size_t slice_size,
size_t end_size) {
CUDA_KERNEL_LOOP(i, remain_size * slice_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
CUDA_KERNEL_LOOP_TYPE(i, remain_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size; // offset inside the slice
IndexT gather_i = 0;
int64_t temp = slice_size;
for (int64_t j = end_size - 1; j >= 0; --j) {
......@@ -92,14 +92,14 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
}
// index size
int index_size = index.dims()[0];
int64_t index_size = index.dims()[0];
auto src_dims = src.dims();
framework::DDim output_dims(src_dims);
output_dims[0] = index_size;
// slice size
int slice_size = 1;
int64_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
const T* p_src = src.data<T>();
......@@ -107,8 +107,8 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
T* p_output = output->data<T>();
int block = 512;
int n = slice_size * index_size;
int grid = (n + block - 1) / block;
int64_t n = slice_size * index_size;
int64_t grid = (n + block - 1) / block;
GatherCUDAKernel<T, IndexT><<<
grid, block, 0,
......@@ -143,21 +143,21 @@ void GPUGatherNd(const framework::ExecutionContext& context,
slice_size *= input_dims[i];
}
// source dim
std::vector<int> v_input_dims(input_dims_size);
std::vector<int64_t> v_input_dims(input_dims_size);
for (int i = 0; i < input_dims_size; ++i) {
v_input_dims[i] = static_cast<int>(input_dims[i]);
v_input_dims[i] = input_dims[i];
}
auto& dev_ctx = context.cuda_device_context();
int bytes = input_dims_size * sizeof(int);
int64_t bytes = input_dims_size * sizeof(int64_t);
auto p_input_dims = memory::Alloc(dev_ctx, bytes);
int* g_input_dims = reinterpret_cast<int*>(p_input_dims->ptr());
int64_t* g_input_dims = reinterpret_cast<int64_t*>(p_input_dims->ptr());
memory::Copy(gplace, g_input_dims, cplace, v_input_dims.data(), bytes,
ctx.stream());
int block = 512;
int n = slice_size * remain_numel;
int grid = (n + block - 1) / block;
int64_t n = slice_size * remain_numel;
int64_t grid = (n + block - 1) / block;
GatherNdCUDAKernel<T, IndexT><<<
grid, block, 0,
......@@ -168,16 +168,16 @@ void GPUGatherNd(const framework::ExecutionContext& context,
template <typename T, typename U>
__global__ void GatherGPUKernel(const T* input, const U* index, T* out,
int outer_dim_size, int inner_dim_size,
int out_index_dim_size,
int input_index_dim_size, int size) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int outer_size = outer_dim_size * out_index_dim_size;
int64_t outer_dim_size, int64_t inner_dim_size,
int64_t out_index_dim_size,
int64_t input_index_dim_size, int64_t size) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
int64_t outer_size = outer_dim_size * out_index_dim_size;
for (; idx < size; idx += blockDim.x * gridDim.x) {
int inner_dim_index = idx / outer_size;
int next_idx = idx - outer_size * inner_dim_index;
int index_dim_index = next_idx / outer_dim_size;
int index_val = index[index_dim_index];
int64_t inner_dim_index = idx / outer_size;
int64_t next_idx = idx - outer_size * inner_dim_index;
int64_t index_dim_index = next_idx / outer_dim_size;
U index_val = index[index_dim_index];
PADDLE_ENFORCE(
index_val >= 0 && index_val < input_index_dim_size,
......@@ -187,8 +187,8 @@ __global__ void GatherGPUKernel(const T* input, const U* index, T* out,
"be less than [%d] and greater than or equal to 0, but received [%d]",
input_index_dim_size, index_val);
int out_dim_index = next_idx - outer_dim_size * index_dim_index;
int input_index =
int64_t out_dim_index = next_idx - outer_dim_size * index_dim_index;
int64_t input_index =
inner_dim_index * (outer_dim_size * input_index_dim_size) +
index_val * outer_dim_size + out_dim_index;
out[idx] = input[input_index];
......@@ -197,17 +197,19 @@ __global__ void GatherGPUKernel(const T* input, const U* index, T* out,
template <typename T, typename U>
__global__ void GatherGradGPUKernel(const T* input, const U* index, T* out,
int outer_dim_size, int inner_dim_size,
int input_index_dim_size,
int out_index_dim_size, int size) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int64_t outer_dim_size,
int64_t inner_dim_size,
int64_t input_index_dim_size,
int64_t out_index_dim_size, int64_t size) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < size; idx += blockDim.x * gridDim.x) {
int inner_dim_index = idx / (outer_dim_size * input_index_dim_size);
int next_idx = idx % (outer_dim_size * input_index_dim_size);
int index_dim_index = next_idx / (outer_dim_size);
int out_dim_index = next_idx % outer_dim_size;
int out_index = inner_dim_index * (outer_dim_size * out_index_dim_size) +
index[index_dim_index] * outer_dim_size + out_dim_index;
int64_t inner_dim_index = idx / (outer_dim_size * input_index_dim_size);
int64_t next_idx = idx % (outer_dim_size * input_index_dim_size);
int64_t index_dim_index = next_idx / (outer_dim_size);
int64_t out_dim_index = next_idx % outer_dim_size;
int64_t out_index =
inner_dim_index * (outer_dim_size * out_index_dim_size) +
index[index_dim_index] * outer_dim_size + out_dim_index;
paddle::platform::CudaAtomicAdd(out + out_index, *(input + idx));
}
}
......@@ -217,8 +219,8 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
const int axis, Tensor* out,
const paddle::platform::Place& place,
const framework::ExecutionContext& ctx) {
int index_size = index->numel();
int input_size = input->numel();
int64_t index_size = index->numel();
int64_t input_size = input->numel();
auto input_dim = input->dims();
auto* input_data = input->data<T>();
auto* index_data = index->data<U>();
......@@ -226,11 +228,11 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
if (input->numel() == 0) return;
int axis_index = axis;
int index_dim_size = input_dim[axis_index];
int64_t index_dim_size = input_dim[axis_index];
int inner_dim_size = 1;
int outer_dim_size = 1;
std::vector<int> out_dim_vec;
int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
std::vector<int64_t> out_dim_vec;
for (int i = 0; i < axis_index; i++) {
inner_dim_size *= input_dim[i];
......@@ -245,7 +247,7 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
out->Resize(out_dim);
auto* out_data = out->mutable_data<T>(place);
int out_size = out->numel();
int64_t out_size = out->numel();
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), out_size);
......@@ -262,17 +264,17 @@ void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index,
const paddle::platform::Place& place,
const framework::ExecutionContext& ctx) {
auto* index_data = index->data<U>();
int index_size = index->numel();
int input_size = input->numel();
int64_t index_size = index->numel();
int64_t input_size = input->numel();
auto input_dim = input->dims();
auto* input_data = input->data<T>();
if (input->numel() == 0) return;
int axis_index = axis;
int input_index_dim_size = input_dim[axis_index];
int64_t input_index_dim_size = input_dim[axis_index];
int inner_dim_size = 1;
int outer_dim_size = 1;
int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
for (int i = 0; i < axis_index; i++) {
inner_dim_size *= input_dim[i];
......@@ -284,7 +286,7 @@ void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index,
auto* out_data = out->mutable_data<T>(place);
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
auto out_dim = out->dims();
int out_index_dim_size = out_dim[axis_index];
int64_t out_index_dim_size = out_dim[axis_index];
operators::math::set_constant(*dev_ctx, out, 0.0);
platform::GpuLaunchConfig config =
......
......@@ -65,10 +65,10 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
T* p_output = output->data<T>();
// slice size
int slice_size = 1;
int64_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
// input size
int input_size = src_dims[0] * slice_size;
int64_t input_size = src_dims[0] * slice_size;
const size_t slice_bytes = slice_size * sizeof(T);
......@@ -144,16 +144,16 @@ template <typename T, typename U>
void GatherV2Function(const Tensor* input, const Tensor* index, int axis,
Tensor* out, const paddle::platform::Place& place) {
auto* index_data = index->data<U>();
int index_size = index->numel();
int input_size = input->numel();
int64_t index_size = index->numel();
int64_t input_size = input->numel();
auto input_dim = input->dims();
auto* input_data = input->data<T>();
if (input->numel() == 0) return;
int axis_index = axis;
int input_index_dim_size = input_dim[axis_index];
for (int i = 0; i < index_size; i++) {
int64_t input_index_dim_size = input_dim[axis_index];
for (int64_t i = 0; i < index_size; i++) {
PADDLE_ENFORCE_LT(index_data[i], input_index_dim_size,
platform::errors::OutOfRange(
"The element of Index must be less than the size of "
......@@ -168,9 +168,9 @@ void GatherV2Function(const Tensor* input, const Tensor* index, int axis,
index_data[i], i));
}
int inner_dim_size = 1;
int outer_dim_size = 1;
std::vector<int> out_dim_vec;
int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
std::vector<int64_t> out_dim_vec;
for (int i = 0; i < axis_index; i++) {
inner_dim_size *= input_dim[i];
......@@ -187,11 +187,11 @@ void GatherV2Function(const Tensor* input, const Tensor* index, int axis,
auto* out_data = out->mutable_data<T>(place);
int out_index = 0;
for (int i = 0; i < inner_dim_size; i++) {
for (int j = 0; j < index_size; j++) {
for (int k = 0; k < outer_dim_size; k++) {
int index = k + index_data[j] * outer_dim_size +
(i * input_size / inner_dim_size);
for (int64_t i = 0; i < inner_dim_size; i++) {
for (int64_t j = 0; j < index_size; j++) {
for (int64_t k = 0; k < outer_dim_size; k++) {
int64_t index = k + index_data[j] * outer_dim_size +
(i * input_size / inner_dim_size);
out_data[out_index] = input_data[index];
out_index++;
}
......@@ -210,10 +210,10 @@ void GatherV2GradFunction(const Tensor* input, const Tensor* index,
if (input->numel() == 0) return;
int axis_index = axis;
int input_index_dim_size = input_dim[axis_index];
int64_t input_index_dim_size = input_dim[axis_index];
int inner_dim_size = 1;
int outer_dim_size = 1;
int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
for (int i = 0; i < axis_index; i++) {
inner_dim_size *= input_dim[i];
......@@ -225,14 +225,14 @@ void GatherV2GradFunction(const Tensor* input, const Tensor* index,
auto* out_data = out->mutable_data<T>(place);
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
auto out_dim = out->dims();
int out_index_dim_size = out_dim[axis_index];
int64_t out_index_dim_size = out_dim[axis_index];
operators::math::set_constant(*dev_ctx, out, 0.0);
for (int i = 0; i < inner_dim_size; i++) {
for (int j = 0; j < input_index_dim_size; j++) {
for (int k = 0; k < outer_dim_size; k++) {
int index = k + index_data[j] * outer_dim_size +
i * outer_dim_size * out_index_dim_size;
for (int64_t i = 0; i < inner_dim_size; i++) {
for (int64_t j = 0; j < input_index_dim_size; j++) {
for (int64_t k = 0; k < outer_dim_size; k++) {
int64_t index = k + index_data[j] * outer_dim_size +
i * outer_dim_size * out_index_dim_size;
out_data[index] += input_data[j * outer_dim_size + k];
}
}
......
......@@ -35,34 +35,30 @@ using Tensor = framework::Tensor;
template <typename T, typename IndexT = int>
typename std::enable_if<std::is_floating_point<T>::value>::type
elementwise_inner_add(const framework::ExecutionContext& ctx,
const T* src_pointer, const T* dist_pointer,
T* result_dist_pointer, const framework::Tensor& src,
framework::Tensor* dist, const int& src_index,
const IndexT& dist_index, const int& slice_size,
const size_t& slice_bytes) {
const T* src_pointer, T* dst_pointer, size_t src_index,
IndexT dst_index, size_t slice_size) {
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
blas.VADD(slice_size, src_pointer + src_index * slice_size,
dist_pointer + dist_index * slice_size,
result_dist_pointer + dist_index * slice_size);
dst_pointer + dst_index * slice_size,
dst_pointer + dst_index * slice_size);
}
template <typename T, typename IndexT = int>
typename std::enable_if<!std::is_floating_point<T>::value>::type
elementwise_inner_add(const framework::ExecutionContext& ctx,
const T* src_pointer, const T* dist_pointer,
T* result_dist_pointer, const framework::Tensor& src,
framework::Tensor* dist, const int& src_index,
const IndexT& dist_index, const int& slice_size,
const size_t& slice_bytes) {
auto src_slice = src.Slice(src_index, src_index + 1);
auto dist_slice = dist->Slice(dist_index, dist_index + 1);
auto eigen_src = framework::EigenVector<T>::Flatten(src_slice);
auto eigen_dist = framework::EigenVector<T>::Flatten(dist_slice);
eigen_dist += eigen_src;
const T* src_pointer, T* dst_pointer, size_t src_index,
IndexT dst_index, size_t slice_size) {
using EigenVector = typename framework::EigenTensor<T, 1>::Type;
using ConstEigenVector = typename framework::EigenTensor<T, 1>::ConstType;
framework::EigenDim<1>::Type dim;
dim[0] = slice_size;
ConstEigenVector eigen_src(src_pointer + src_index * slice_size, dim);
EigenVector eigen_dst(dst_pointer + dst_index * slice_size, dim);
eigen_dst += eigen_src;
}
/**
* Return an updated tensor from source tensor, scattered according to index:
* dst[i] = src[index[i]]
......@@ -91,7 +87,7 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
"But received value is [%d]",
index.dims().size()));
}
int index_size = index.dims()[0];
int64_t index_size = index.dims()[0];
auto src_dims = src.dims();
auto dst_dims = output->dims();
......@@ -146,7 +142,7 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
"expect index'dims shape is 1 or 2 and index.dims[1] is 1"
"but got index'dims shape is %d",
index.dims().size()));
int index_size = index.dims()[0];
int64_t index_size = index.dims()[0];
auto src_dims = src.dims();
auto dst_dims = output->dims();
......@@ -154,8 +150,7 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
const T* p_src = src.data<T>();
const IndexT* p_index = index.data<IndexT>();
const T* p_output = output->data<T>();
T* result_p_output = output->data<T>();
T* p_output = output->data<T>();
// check src shape and dst shape should match
for (int i = 1; i < src_dims.size(); i++)
......@@ -174,26 +169,25 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
const size_t& slice_bytes = slice_size * sizeof(T);
// if not in overwrite mode, need to init output data
for (int i = 0; i < index_size; ++i) {
const IndexT& index_ = p_index[i];
memset(result_p_output + slice_size * index_, 0, slice_bytes);
for (int64_t i = 0; i < index_size; ++i) {
const IndexT& index_val = p_index[i];
memset(p_output + slice_size * index_val, 0, slice_bytes);
}
// if not in overwrite mode, need to init output data
for (int i = 0; i < index_size; ++i) {
const IndexT& index_ = p_index[i];
const IndexT& index_val = p_index[i];
PADDLE_ENFORCE_GE(index_, 0,
PADDLE_ENFORCE_GE(index_val, 0,
platform::errors::OutOfRange(
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be greater than or equal to 0, but received [%d]",
index_));
index_val));
elementwise_inner_add<T, IndexT>(ctx, p_src, p_output, result_p_output, src,
output, i, index_, slice_size,
slice_bytes);
elementwise_inner_add<T, IndexT>(ctx, p_src, p_output, i, index_val,
slice_size);
}
}
......@@ -202,14 +196,14 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
template <typename T, typename IndexT = int>
void CPUScatterGradForX(const platform::DeviceContext& ctx, const Tensor& index,
Tensor* output) {
int index_size = index.dims()[0];
int64_t index_size = index.dims()[0];
auto dst_dims = output->dims();
const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>();
size_t slice_size = 1;
for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
const size_t slice_bytes = slice_size * sizeof(T);
for (int i = 0; i < index_size; ++i) {
for (int64_t i = 0; i < index_size; ++i) {
const IndexT& index_ = p_index[i];
memset(p_output + slice_size * index_, 0, slice_bytes);
}
......@@ -231,8 +225,7 @@ void ScatterNdAdd(const framework::ExecutionContext& ctx, const Tensor& update,
const T* p_update = update.data<T>();
const IndexT* p_index = index.data<IndexT>();
T* result_p_output = output->data<T>();
const T* p_output = output->data<T>();
T* p_output = output->data<T>();
// final dim
int64_t end_size = index_dims[index_dims_size - 1];
......@@ -244,10 +237,9 @@ void ScatterNdAdd(const framework::ExecutionContext& ctx, const Tensor& update,
for (int64_t i = end_size; i < output_dims_size; ++i) {
slice_size *= output_dims[i];
}
const size_t slice_bytes = slice_size * sizeof(T);
for (int64_t i = 0; i < remain_numel; ++i) {
IndexT index_ = 0;
IndexT index_val = 0;
IndexT temp = 1;
for (int64_t j = end_size - 1; j >= 0; --j) {
IndexT index_value = p_index[i * end_size + j];
......@@ -260,12 +252,11 @@ void ScatterNdAdd(const framework::ExecutionContext& ctx, const Tensor& update,
"be less than [%d] and greater or equal to 0, but received [%d]",
output_dims[j], index_value));
index_ += (index_value * temp);
index_val += (index_value * temp);
temp *= output_dims[j];
}
elementwise_inner_add<T, IndexT>(ctx, p_update, p_output, result_p_output,
update, output, i, index_, slice_size,
slice_bytes);
elementwise_inner_add<T, IndexT>(ctx, p_update, p_output, i, index_val,
slice_size);
}
}
......
......@@ -20,6 +20,7 @@ from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.framework import core
from paddle.fluid.dygraph.base import switch_to_static_graph
def gather_numpy(x, index, axis):
......@@ -247,6 +248,36 @@ class API_TestDygraphGather(unittest.TestCase):
self.assertTrue(np.allclose(output_np, expected_output))
paddle.enable_static()
def test_large_data(self):
if not paddle.is_compiled_with_cuda():
return
x = np.random.rand(226862, 256).astype("float32")
index = np.random.randint(0, 22682, size=(11859027))
def test_dygraph():
with fluid.dygraph.guard():
gpu_out = paddle.gather(
paddle.to_tensor(x), paddle.to_tensor(index))
return gpu_out.numpy()
@switch_to_static_graph
def test_static_graph():
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape)
index_t = paddle.static.data(
name="index", dtype=index.dtype, shape=index.shape)
out_t = paddle.gather(x_t, index_t)
feed = {x_t.name: x, index_t.name: index}
fetch = [out_t]
gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0))
gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0]
return gpu_value
self.assertTrue(np.array_equal(test_dygraph(), test_static_graph()))
class TestGathertError(unittest.TestCase):
def test_error1(self):
......
......@@ -19,6 +19,7 @@ import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
import paddle
from paddle.fluid.dygraph.base import switch_to_static_graph
def numpy_scatter_nd(ref, index, updates, fun):
......@@ -227,6 +228,50 @@ class TestScatterNdOpAPI(unittest.TestCase):
output4 = fluid.layers.scatter_nd(
index4, updates4, shape4, name='scatter_nd')
def testcase5(self):
if not fluid.core.is_compiled_with_cuda():
return
shape = [2, 3, 4]
x = np.arange(int(np.prod(shape))).reshape(shape)
index = np.array([[0, 0, 2], [0, 1, 2]])
val = np.array([-1, -3])
with fluid.dygraph.guard():
device = paddle.get_device()
paddle.set_device('gpu')
gpu_value = paddle.scatter_nd_add(
paddle.to_tensor(x),
paddle.to_tensor(index), paddle.to_tensor(val))
paddle.set_device('cpu')
cpu_value = paddle.scatter_nd_add(
paddle.to_tensor(x),
paddle.to_tensor(index), paddle.to_tensor(val))
self.assertTrue(
np.array_equal(gpu_value.numpy(), cpu_value.numpy()))
paddle.set_device(device)
@switch_to_static_graph
def test_static_graph():
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape)
index_t = paddle.static.data(
name="index", dtype=index.dtype, shape=index.shape)
val_t = paddle.static.data(
name="val", dtype=val.dtype, shape=val.shape)
out_t = paddle.scatter_nd_add(x_t, index_t, val_t)
feed = {x_t.name: x, index_t.name: index, val_t.name: val}
fetch = [out_t]
gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0))
gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0]
cpu_exe = paddle.static.Executor(paddle.CPUPlace())
cpu_value = cpu_exe.run(feed=feed, fetch_list=fetch)[0]
self.assertTrue(np.array_equal(gpu_value, cpu_value))
test_static_graph()
#Test Raise Error
class TestScatterNdOpRaise(unittest.TestCase):
......@@ -304,4 +349,5 @@ class TestDygraph(unittest.TestCase):
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
......@@ -235,4 +235,5 @@ class TestScatterInplaceAPI(TestScatterAPI):
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册