未验证 提交 072347ff 编写于 作者: W wawltor 提交者: GitHub

Fix gather and scatter op has same index bug cherry-pick from #17952

test=release/1.5
cherry-pick from #17952
The scatter op has a calc bug when the indices has same index, the scatter op use overwrite mode to calculate the same index, fix this bug by using the accumulate mode to calculate the same index.At the same time, the gather op has the same bug when the op calc the grad. And we use the lib of open-blas and eigen to optimize the time cost in accumulate mode.
上级 80a3fd2e
...@@ -153,8 +153,8 @@ paddle.fluid.layers.image_resize (ArgSpec(args=['input', 'out_shape', 'scale', ' ...@@ -153,8 +153,8 @@ paddle.fluid.layers.image_resize (ArgSpec(args=['input', 'out_shape', 'scale', '
paddle.fluid.layers.image_resize_short (ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)), ('document', '099b9f051e6247ae661e4a7b4fd3f89a')) paddle.fluid.layers.image_resize_short (ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)), ('document', '099b9f051e6247ae661e4a7b4fd3f89a'))
paddle.fluid.layers.resize_bilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '746bf58fdb1bd475f8c5f996b05b0e52')) paddle.fluid.layers.resize_bilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '746bf58fdb1bd475f8c5f996b05b0e52'))
paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True)), ('document', '9baf9288c862161ff850d45228047a5e')) paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True)), ('document', '9baf9288c862161ff850d45228047a5e'))
paddle.fluid.layers.gather (ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None), ('document', '01a198d6fff38d5f0d8180a40b228085')) paddle.fluid.layers.gather (ArgSpec(args=['input', 'index', 'overwrite'], varargs=None, keywords=None, defaults=(True,)), ('document', '3569a6002a96c7f6b5e5bcfdc402df13'))
paddle.fluid.layers.scatter (ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '846a53fd2991bdaab3a8134008eef0c7')) paddle.fluid.layers.scatter (ArgSpec(args=['input', 'index', 'updates', 'name', 'overwrite'], varargs=None, keywords=None, defaults=(None, True)), ('document', '69b22affd4a6326502af166f04c095ab'))
paddle.fluid.layers.sequence_scatter (ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '71df5136cf03b06c65027b692fe78f1a')) paddle.fluid.layers.sequence_scatter (ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '71df5136cf03b06c65027b692fe78f1a'))
paddle.fluid.layers.random_crop (ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,)), ('document', 'c9ab9e460ef0a1823249935a30e82c66')) paddle.fluid.layers.random_crop (ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,)), ('document', 'c9ab9e460ef0a1823249935a30e82c66'))
paddle.fluid.layers.mean_iou (ArgSpec(args=['input', 'label', 'num_classes'], varargs=None, keywords=None, defaults=None), ('document', 'e3b6630ba43cb13dfeeb1601cb64d671')) paddle.fluid.layers.mean_iou (ArgSpec(args=['input', 'label', 'num_classes'], varargs=None, keywords=None, defaults=None), ('document', 'e3b6630ba43cb13dfeeb1601cb64d671'))
......
...@@ -74,6 +74,13 @@ class GatherOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -74,6 +74,13 @@ class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "The source input of gather op"); AddInput("X", "The source input of gather op");
AddInput("Index", "The index input of gather op"); AddInput("Index", "The index input of gather op");
AddOutput("Out", "The output of gather op"); AddOutput("Out", "The output of gather op");
AddAttr<bool>(
"overwrite",
"(bool, default: False) "
"In backward process, calc the grad when has same index,"
"If true, update the grad using the overwrite mode in same index,"
"If false, using the accumulate mode in same index.")
.SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
Gather Operator. Gather Operator.
......
...@@ -76,9 +76,11 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -76,9 +76,11 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
paddle::framework::DataTypeToString(framework::proto::VarType::INT32), paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64)); paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
if (index_type == framework::proto::VarType::INT32) { if (index_type == framework::proto::VarType::INT32) {
GPUScatterAssign<T, int>(ctx.device_context(), *dO, *index, dX); GPUScatterAssign<T, int>(ctx, *dO, *index, dX,
ctx.Attr<bool>("overwrite"));
} else if (index_type == framework::proto::VarType::INT64) { } else if (index_type == framework::proto::VarType::INT64) {
GPUScatterAssign<T, int64_t>(ctx.device_context(), *dO, *index, dX); GPUScatterAssign<T, int64_t>(ctx, *dO, *index, dX,
ctx.Attr<bool>("overwrite"));
} }
} }
}; };
......
...@@ -71,6 +71,7 @@ class GatherGradientOpKernel : public framework::OpKernel<T> { ...@@ -71,6 +71,7 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
.eigen_device(); .eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0)); dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return; if (dO->numel() == 0) return;
bool overwrite = ctx.Attr<bool>("overwrite");
const auto &index_type = index->type(); const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 || bool index_type_match = index_type == framework::proto::VarType::INT32 ||
...@@ -82,9 +83,17 @@ class GatherGradientOpKernel : public framework::OpKernel<T> { ...@@ -82,9 +83,17 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
paddle::framework::DataTypeToString(framework::proto::VarType::INT32), paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64)); paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
if (index_type == framework::proto::VarType::INT32) { if (index_type == framework::proto::VarType::INT32) {
ScatterAssign<T, int>(ctx.device_context(), *dO, *index, dX); if (overwrite) {
ScatterAssign<T, int32_t>(ctx.device_context(), *dO, *index, dX);
} else {
ScatterAssignAdd<T, int32_t>(ctx, *dO, *index, dX);
}
} else if (index_type == framework::proto::VarType::INT64) { } else if (index_type == framework::proto::VarType::INT64) {
if (overwrite) {
ScatterAssign<T, int64_t>(ctx.device_context(), *dO, *index, dX); ScatterAssign<T, int64_t>(ctx.device_context(), *dO, *index, dX);
} else {
ScatterAssignAdd<T, int64_t>(ctx, *dO, *index, dX);
}
} }
} }
}; };
......
...@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <unordered_set>
#include "math/math_function.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
...@@ -24,17 +27,33 @@ using Tensor = framework::Tensor; ...@@ -24,17 +27,33 @@ using Tensor = framework::Tensor;
#define CUDA_1D_KERNEL_LOOP(i, n) \ #define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x) i += blockDim.x * gridDim.x)
template <typename T, typename IndexT = int>
__global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output,
size_t index_size, size_t slice_size,
bool overwrite) {
CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
IndexT scatter_i = indices[indices_i];
IndexT out_i = scatter_i * slice_size + slice_i;
*(output + out_i) = static_cast<T>(0);
}
}
template <typename T, typename IndexT = int> template <typename T, typename IndexT = int>
__global__ void ScatterCUDAKernel(const T* params, const IndexT* indices, __global__ void ScatterCUDAKernel(const T* params, const IndexT* indices,
T* output, size_t index_size, T* output, size_t index_size,
size_t slice_size) { size_t slice_size, bool overwrite) {
CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) { CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
int indices_i = i / slice_size; int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice int slice_i = i - indices_i * slice_size; // offset inside the slice
IndexT scatter_i = indices[indices_i]; IndexT scatter_i = indices[indices_i];
IndexT out_i = scatter_i * slice_size + slice_i; IndexT out_i = scatter_i * slice_size + slice_i;
if (overwrite) {
*(output + out_i) = *(params + i); *(output + out_i) = *(params + i);
} else {
paddle::platform::CudaAtomicAdd(output + out_i, *(params + i));
}
} }
} }
...@@ -47,10 +66,13 @@ __global__ void ScatterCUDAKernel(const T* params, const IndexT* indices, ...@@ -47,10 +66,13 @@ __global__ void ScatterCUDAKernel(const T* params, const IndexT* indices,
* return: output tensor * return: output tensor
*/ */
template <typename T, typename IndexT = int> template <typename T, typename IndexT = int>
void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src, void GPUScatterAssign(const framework::ExecutionContext& context,
const Tensor& index, Tensor* output) { const Tensor& src, const Tensor& index, Tensor* output,
bool overwrite = true) {
// PADDLE_ENFORCE(platform::is_gpu_place(place)); // PADDLE_ENFORCE(platform::is_gpu_place(place));
// check index of shape 1-D // check index of shape 1-D
const auto& ctx = context.device_context();
PADDLE_ENFORCE(index.dims().size() == 1 || PADDLE_ENFORCE(index.dims().size() == 1 ||
(index.dims().size() == 2 && index.dims()[1] == 1)); (index.dims().size() == 2 && index.dims()[1] == 1));
int index_size = index.dims()[0]; int index_size = index.dims()[0];
...@@ -66,15 +88,25 @@ void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src, ...@@ -66,15 +88,25 @@ void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
const T* p_src = src.data<T>(); const T* p_src = src.data<T>();
const IndexT* p_index = index.data<IndexT>(); const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>(); T* p_output = output->data<T>();
const size_t& slice_bytes = slice_size * sizeof(T);
// set block and grid num
int block = 512; int block = 512;
int n = slice_size * index_size; int n = slice_size * index_size;
int grid = (n + block - 1) / block; int grid = (n + block - 1) / block;
// if not overwrite mode, init data
if (!overwrite) {
ScatterInitCUDAKernel<T, IndexT><<<
grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
p_index, p_output, index_size, slice_size, overwrite);
}
ScatterCUDAKernel<T, IndexT><<< ScatterCUDAKernel<T, IndexT><<<
grid, block, 0, grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>( reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
p_src, p_index, p_output, index_size, slice_size); p_src, p_index, p_output, index_size, slice_size, overwrite);
} }
} // namespace operators } // namespace operators
......
...@@ -14,11 +14,14 @@ limitations under the License. */ ...@@ -14,11 +14,14 @@ limitations under the License. */
#pragma once #pragma once
#include <cstring> #include <cstring>
#include <string>
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "unordered_set"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -26,7 +29,42 @@ namespace operators { ...@@ -26,7 +29,42 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
/** /**
* Return a updated tensor from source tensor, scattered according to index: * Return the updated array pointer, use blas or eigen lib to optimize time
* cost
*/
template <typename T, typename IndexT = int>
typename std::enable_if<std::is_floating_point<T>::value>::type
elementwise_inner_add(const framework::ExecutionContext& ctx,
const T* src_pointer, const T* dist_pointer,
T* result_dist_pointer, const framework::Tensor& src,
framework::Tensor* dist, const int& src_index,
const IndexT& dist_index, const int& slice_size,
const size_t& slice_bytes) {
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
blas.VADD(slice_size, src_pointer + src_index * slice_size,
dist_pointer + dist_index * slice_size,
result_dist_pointer + dist_index * slice_size);
}
template <typename T, typename IndexT = int>
typename std::enable_if<!std::is_floating_point<T>::value>::type
elementwise_inner_add(const framework::ExecutionContext& ctx,
const T* src_pointer, const T* dist_pointer,
T* result_dist_pointer, const framework::Tensor& src,
framework::Tensor* dist, const int& src_index,
const IndexT& dist_index, const int& slice_size,
const size_t& slice_bytes) {
auto src_slice = src.Slice(src_index, src_index + 1);
auto dist_slice = dist->Slice(dist_index, dist_index + 1);
auto eigen_src = framework::EigenVector<T>::Flatten(src_slice);
auto eigen_dist = framework::EigenVector<T>::Flatten(dist_slice);
eigen_dist += eigen_src;
}
/**
* Return an updated tensor from source tensor, scattered according to index:
* dst[i] = src[index[i]] * dst[i] = src[index[i]]
* input[src]: type-T source Tensor * input[src]: type-T source Tensor
* input[index]: type-IndexT index Tensor (1-D) * input[index]: type-IndexT index Tensor (1-D)
...@@ -64,5 +102,47 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src, ...@@ -64,5 +102,47 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
} }
} }
template <typename T, typename IndexT = int>
void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
const Tensor& index, Tensor* output) {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.device_context().GetPlace()));
// check index of shape 1-D
PADDLE_ENFORCE(index.dims().size() == 1 ||
(index.dims().size() == 2 && index.dims()[1] == 1));
int index_size = index.dims()[0];
auto src_dims = src.dims();
auto dst_dims = output->dims();
const T* p_src = src.data<T>();
const IndexT* p_index = index.data<IndexT>();
const T* p_output = output->data<T>();
T* result_p_output = output->data<T>();
// check src shape and dst shape should match
for (int i = 1; i < src_dims.size(); i++)
PADDLE_ENFORCE(src_dims[i] == dst_dims[i]);
// slice size
size_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
const size_t& slice_bytes = slice_size * sizeof(T);
// if not in overwrite mode, need to init output data
for (int i = 0; i < index_size; ++i) {
const IndexT& index_ = p_index[i];
memset(result_p_output + slice_size * index_, 0, slice_bytes);
}
for (int i = 0; i < index_size; ++i) {
const IndexT& index_ = p_index[i];
elementwise_inner_add<T, IndexT>(ctx, p_src, p_output, result_p_output, src,
output, i, index_, slice_size,
slice_bytes);
}
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -80,6 +80,14 @@ class ScatterOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -80,6 +80,14 @@ class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Ids", "The index input of scatter op where X will be updated"); AddInput("Ids", "The index input of scatter op where X will be updated");
AddInput("Updates", "The updated value of scatter op"); AddInput("Updates", "The updated value of scatter op");
AddOutput("Out", "The output of scatter op"); AddOutput("Out", "The output of scatter op");
AddAttr<bool>("overwrite",
"(bool, defalut: True) "
"The mode that updating the output when has same index,"
"If True, use the overwrite mode to update the output"
"of the same index, if False, use the accumulate mode to"
"update the output of the same index,Default value is True."
"You can set overwrite=False to implement scatter_add.")
.SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
Scatter Operator. Scatter Operator.
......
...@@ -30,10 +30,10 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> { ...@@ -30,10 +30,10 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> {
auto *Ids = ctx.Input<Tensor>("Ids"); auto *Ids = ctx.Input<Tensor>("Ids");
auto *Updates = ctx.Input<Tensor>("Updates"); auto *Updates = ctx.Input<Tensor>("Updates");
auto *Out = ctx.Output<Tensor>("Out"); auto *Out = ctx.Output<Tensor>("Out");
bool overwrite = ctx.Attr<bool>("overwrite");
Out->ShareDataWith(*X); Out->ShareDataWith(*X);
GPUScatterAssign<T>(ctx, *Updates, *Ids, Out, overwrite);
GPUScatterAssign<T>(ctx.device_context(), *Updates, *Ids, Out);
} }
}; };
......
...@@ -33,11 +33,33 @@ class ScatterOpKernel : public framework::OpKernel<T> { ...@@ -33,11 +33,33 @@ class ScatterOpKernel : public framework::OpKernel<T> {
auto *Ids = ctx.Input<Tensor>("Ids"); auto *Ids = ctx.Input<Tensor>("Ids");
auto *Updates = ctx.Input<Tensor>("Updates"); auto *Updates = ctx.Input<Tensor>("Updates");
auto *Out = ctx.Output<Tensor>("Out"); auto *Out = ctx.Output<Tensor>("Out");
double overwrite = ctx.Attr<bool>("overwrite");
// In place output: Out = X, Out[Ids] = Updates // In place output: Out = X, Out[Ids] = Updates
framework::TensorCopySync(*X, ctx.GetPlace(), Out); framework::TensorCopySync(*X, ctx.GetPlace(), Out);
// Apply ScatterUpdate: Out[index] = Updates[:] // Apply ScatterUpdate: Out[index] = Updates[:]
ScatterAssign<T>(ctx.device_context(), *Updates, *Ids, Out); const auto &index_type = Ids->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE(
index_type_match,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
if (overwrite) {
if (index_type == framework::proto::VarType::INT32) {
ScatterAssign<T, int32_t>(ctx.device_context(), *Updates, *Ids, Out);
} else {
ScatterAssign<T, int64_t>(ctx.device_context(), *Updates, *Ids, Out);
}
} else {
if (index_type == framework::proto::VarType::INT32) {
ScatterAssignAdd<T, int32_t>(ctx, *Updates, *Ids, Out);
} else {
ScatterAssignAdd<T, int64_t>(ctx, *Updates, *Ids, Out);
}
}
} }
}; };
......
...@@ -7855,7 +7855,7 @@ def image_resize_short(input, out_short_len, resample='BILINEAR'): ...@@ -7855,7 +7855,7 @@ def image_resize_short(input, out_short_len, resample='BILINEAR'):
return image_resize(input=input, out_shape=out_shape, resample=resample) return image_resize(input=input, out_shape=out_shape, resample=resample)
def gather(input, index): def gather(input, index, overwrite=True):
""" """
**Gather Layer** **Gather Layer**
...@@ -7886,6 +7886,12 @@ def gather(input, index): ...@@ -7886,6 +7886,12 @@ def gather(input, index):
Args: Args:
input (Variable): The source input with rank>=1. input (Variable): The source input with rank>=1.
index (Variable): The index input with rank=1. index (Variable): The index input with rank=1.
overwrite (bool): The mode that updating the grad when has same index.
If True, use the overwrite mode to update the grad of the same index,
if False, use the accumulate mode to update the grad of the same index.
Default value is True.
Returns: Returns:
output (Variable): The output is a tensor with the same rank as input. output (Variable): The output is a tensor with the same rank as input.
...@@ -7905,11 +7911,12 @@ def gather(input, index): ...@@ -7905,11 +7911,12 @@ def gather(input, index):
type="gather", type="gather",
inputs={"X": input, inputs={"X": input,
"Index": index}, "Index": index},
outputs={"Out": out}) outputs={"Out": out},
attrs={'overwrite': overwrite})
return out return out
def scatter(input, index, updates, name=None): def scatter(input, index, updates, name=None, overwrite=True):
""" """
**Scatter Layer** **Scatter Layer**
...@@ -7927,6 +7934,10 @@ def scatter(input, index, updates, name=None): ...@@ -7927,6 +7934,10 @@ def scatter(input, index, updates, name=None):
int32 or int64 as it is used as indexes. int32 or int64 as it is used as indexes.
updates (Variable): The updated value of scatter op. updates (Variable): The updated value of scatter op.
name (str|None): The output variable name. Default None. name (str|None): The output variable name. Default None.
overwrite (bool): The mode that updating the output when has same index.
If True, use the overwrite mode to update the output of the same index,
if False, use the accumulate mode to update the output of the same index.
Default value is True.You can set overwrite=False to implement scatter_add.
Returns: Returns:
output (Variable): The output is a tensor with the same shape as input. output (Variable): The output is a tensor with the same shape as input.
...@@ -7951,6 +7962,7 @@ def scatter(input, index, updates, name=None): ...@@ -7951,6 +7962,7 @@ def scatter(input, index, updates, name=None):
inputs={"X": input, inputs={"X": input,
"Ids": index, "Ids": index,
"Updates": updates}, "Updates": updates},
attrs={'overwrite': overwrite},
outputs={"Out": out}) outputs={"Out": out})
return out return out
......
...@@ -79,5 +79,32 @@ class TestCase3(TestGatherOp): ...@@ -79,5 +79,32 @@ class TestCase3(TestGatherOp):
self.index_type = "int64" self.index_type = "int64"
class TestCase4(TestGatherOp):
def config(self):
self.x_shape = (10, 20)
self.attrs = {'overwrite': False}
self.x_type = "double"
self.index = [1, 1]
self.index_type = "int32"
class TestCase5(TestGatherOp):
def config(self):
self.x_shape = (10, 20)
self.attrs = {'overwrite': False}
self.x_type = "float"
self.index = [1, 1, 3]
self.index_type = "int32"
class TestCase6(TestGatherOp):
def config(self):
self.x_shape = (10, 20)
self.attrs = {'overwrite': True}
self.x_type = "float"
self.index = [1, 3]
self.index_type = "int32"
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle.fluid.core as core
class TestScatterOp(OpTest): class TestScatterOp(OpTest):
...@@ -37,5 +38,98 @@ class TestScatterOp(OpTest): ...@@ -37,5 +38,98 @@ class TestScatterOp(OpTest):
self.check_grad(['Updates'], 'Out', in_place=True) self.check_grad(['Updates'], 'Out', in_place=True)
class TestScatterOp0(OpTest):
def setUp(self):
self.op_type = "scatter"
ref_np = np.ones((3, 3)).astype("float32")
index_np = np.array([1, 2]).astype("int32")
updates_np = np.random.random((2, 3)).astype("float32")
output_np = np.copy(ref_np)
output_np[index_np] = updates_np
self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
self.attrs = {'overwrite': True}
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['Updates'], 'Out', in_place=True)
class TestScatterOp1(OpTest):
def setUp(self):
self.op_type = "scatter"
ref_np = np.ones((3, 3)).astype("float32")
zeros_np = np.zeros([2, 3]).astype('float32')
index_np = np.array([1, 1]).astype("int32")
updates_np = np.random.random((2, 3)).astype("float32")
output_np = np.copy(ref_np)
output_np[index_np] = zeros_np
for i in range(0, len(index_np)):
output_np[index_np[i]] += updates_np[i]
self.attrs = {'overwrite': False}
self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['Updates'], 'Out', in_place=True)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestScatterOp2(OpTest):
def setUp(self):
self.op_type = "scatter"
ref_np = np.ones((3, 3)).astype("float32")
index_np = np.array([1, 2]).astype("int32")
updates_np = np.random.random((2, 3)).astype("float32")
output_np = np.copy(ref_np)
output_np[index_np] = updates_np
self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
self.outputs = {'Out': output_np}
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-3)
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestScatterOp3(OpTest):
def setUp(self):
self.op_type = "scatter"
ref_np = np.ones((3, 3)).astype("float32")
zeros_np = np.zeros([2, 3]).astype('float32')
index_np = np.array([1, 1]).astype("int32")
updates_np = np.random.random((2, 3)).astype("float32")
output_np = np.copy(ref_np)
output_np[index_np] = zeros_np
for i in range(0, len(index_np)):
output_np[index_np[i]] += updates_np[i]
self.attrs = {'overwrite': False}
self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
self.outputs = {'Out': output_np}
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-3)
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册