未验证 提交 4643baa7 编写于 作者: T TTerror 提交者: GitHub

add argsort/scatter for kunlun (#38345)

* add argsort/scatter for kunlun

* update test_scatter

* update xpu.cmake

* update xpu.cmake

* fix scatter
上级 3672480b
...@@ -36,7 +36,7 @@ ENDIF() ...@@ -36,7 +36,7 @@ ENDIF()
if(NOT DEFINED XPU_BASE_URL) if(NOT DEFINED XPU_BASE_URL)
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20211129") SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20211226")
else() else()
SET(XPU_BASE_URL "${XPU_BASE_URL}") SET(XPU_BASE_URL "${XPU_BASE_URL}")
endif() endif()
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/argsort_op.h"
namespace paddle {
namespace operators {
const int XPU_SORT_MAX_SIZE = 16384;
template <typename T, typename TID>
static inline void xpu_argsort(xpu::Context* ctx, const T* input_data,
T* output_data, TID* indices_data, int m, int n,
bool descending) {
int ret =
xpu::sort(ctx, input_data, output_data, indices_data, m, n, descending);
PADDLE_ENFORCE_EQ(
ret, XPU_SUCCESS,
platform::errors::External("XPU sort kernel return wrong value[%d %s].",
ret, XPUAPIErrorMsg[ret]));
}
template <typename T>
static inline void xpu_transpose(xpu::Context* ctx, const T* x, T* y,
const std::vector<int>& xshape,
const std::vector<int>& permute) {
int ret = xpu::transpose(ctx, x, y, xshape, permute);
PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
platform::errors::External(
"XPU transpose kernel return wrong value[%d %s]", ret,
XPUAPIErrorMsg[ret]));
}
template <typename TX, typename TY>
static inline void xpu_cast(xpu::Context* ctx, const TX* x, TY* y, int len) {
int ret = xpu::cast_v2(ctx, x, y, len);
PADDLE_ENFORCE_EQ(
ret, XPU_SUCCESS,
platform::errors::External("XPU cast kernel return wrong value[%d %s]",
ret, XPUAPIErrorMsg[ret]));
}
template <typename T, bool VALUE_NEED_CAST = false,
bool INDEX_NEED_CAST = false>
struct XPUArgsort {
void operator()(xpu::Context* ctx, const T* input_data, T* output_data,
int64_t* indices_data, const std::vector<int>& data_shape,
const std::vector<int>& permute, bool descending) {
xpu::ctx_guard RAII_GUARD(ctx);
int m = data_shape[0] * data_shape[2];
int n = data_shape[1];
int len = data_shape[0] * data_shape[1] * data_shape[2];
std::vector<int> trans_data_shape{data_shape[0], data_shape[2],
data_shape[1]};
T* input_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
T* output_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
int64_t* indices_data_trans = RAII_GUARD.alloc_l3_or_gm<int64_t>(len);
xpu_transpose(ctx, input_data, input_data_trans, data_shape, permute);
xpu_argsort(ctx, input_data_trans, output_data_trans, indices_data_trans, m,
n, descending);
xpu_transpose(ctx, output_data_trans, output_data, trans_data_shape,
permute);
xpu_transpose(ctx, indices_data_trans, indices_data, trans_data_shape,
permute);
}
};
template <typename T>
struct XPUArgsort<T, false, true> {
void operator()(xpu::Context* ctx, const T* input_data, T* output_data,
int64_t* indices_data, const std::vector<int>& data_shape,
const std::vector<int>& permute, bool descending) {
xpu::ctx_guard RAII_GUARD(ctx);
int m = data_shape[0] * data_shape[2];
int n = data_shape[1];
int len = data_shape[0] * data_shape[1] * data_shape[2];
std::vector<int> trans_data_shape{data_shape[0], data_shape[2],
data_shape[1]};
T* input_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
T* output_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
int* indices_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
int64_t* cast_data_int64 = RAII_GUARD.alloc_l3_or_gm<int64_t>(len);
xpu_transpose(ctx, input_data, input_data_trans, data_shape, permute);
xpu_argsort(ctx, input_data_trans, output_data_trans, indices_data_trans, m,
n, descending);
xpu_transpose(ctx, output_data_trans, output_data, trans_data_shape,
permute);
xpu_cast(ctx, indices_data_trans, cast_data_int64, len);
xpu_transpose(ctx, cast_data_int64, indices_data, trans_data_shape,
permute);
}
};
template <>
struct XPUArgsort<int64_t, true, true> {
void operator()(xpu::Context* ctx, const int64_t* input_data,
int64_t* output_data, int64_t* indices_data,
const std::vector<int>& data_shape,
const std::vector<int>& permute, bool descending) {
xpu::ctx_guard RAII_GUARD(ctx);
int m = data_shape[0] * data_shape[2];
int n = data_shape[1];
int len = data_shape[0] * data_shape[1] * data_shape[2];
std::vector<int> trans_data_shape{data_shape[0], data_shape[2],
data_shape[1]};
int* input_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
int* output_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
int* indices_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
int* cast_data_int = RAII_GUARD.alloc_l3_or_gm<int>(len);
int64_t* cast_data_int64 = RAII_GUARD.alloc_l3_or_gm<int64_t>(len);
xpu_cast(ctx, input_data, cast_data_int, len);
xpu_transpose(ctx, cast_data_int, input_data_trans, data_shape, permute);
xpu_argsort(ctx, input_data_trans, output_data_trans, indices_data_trans, m,
n, descending);
xpu_cast(ctx, output_data_trans, cast_data_int64, len);
xpu_transpose(ctx, cast_data_int64, output_data, trans_data_shape, permute);
xpu_cast(ctx, indices_data_trans, cast_data_int64, len);
xpu_transpose(ctx, cast_data_int64, indices_data, trans_data_shape,
permute);
}
};
template <typename T>
class ArgsortXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::Tensor>("X");
auto* output = ctx.Output<framework::Tensor>("Out");
auto* indices = ctx.Output<framework::Tensor>("Indices");
int axis = ctx.Attr<int>("axis");
bool descending = ctx.Attr<bool>("descending");
auto in_dims = input->dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
int n = in_dims[axis];
PADDLE_ENFORCE_LT(
n, XPU_SORT_MAX_SIZE,
platform::errors::InvalidArgument(
"The axis dimension of Input should less than %d, but got %d.",
XPU_SORT_MAX_SIZE, in_dims[axis]));
auto input_data = input->data<T>();
auto output_data = output->mutable_data<T>(ctx.GetPlace());
auto indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
auto& dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>();
int len_before =
framework::product(framework::slice_ddim(in_dims, 0, axis));
int len_after = framework::product(
framework::slice_ddim(in_dims, axis + 1, in_dims.size()));
bool int64_need_cast =
(std::is_same<T, int64_t>::value && n > (XPU_SORT_MAX_SIZE / 2))
? true
: false;
bool index_need_cast = (n > (XPU_SORT_MAX_SIZE / 2)) ? true : false;
std::vector<int> permute_vec{0, 2, 1};
std::vector<int> data_shape{len_before, n, len_after};
if (int64_need_cast) {
XPUArgsort<T, true, true>()(dev_ctx.x_context(), input_data, output_data,
indices_data, data_shape, permute_vec,
descending);
} else if (index_need_cast) {
XPUArgsort<T, false, true>()(dev_ctx.x_context(), input_data, output_data,
indices_data, data_shape, permute_vec,
descending);
} else {
XPUArgsort<T, false, false>()(dev_ctx.x_context(), input_data,
output_data, indices_data, data_shape,
permute_vec, descending);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(argsort, ops::ArgsortXPUKernel<float>,
ops::ArgsortXPUKernel<int>,
ops::ArgsortXPUKernel<int64_t>);
#endif
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <memory>
#include <string>
#include "paddle/fluid/operators/scatter_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class ScatterOpXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Ids");
auto *updates = ctx.Input<Tensor>("Updates");
auto *out = ctx.Output<Tensor>("Out");
bool overwrite = ctx.Attr<bool>("overwrite");
// In place output: Out = X, Out[ids] = Updates
framework::TensorCopy(*x, ctx.GetPlace(), out);
// Apply ScatterUpdate: Out[index] = Updates[:]
const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s].",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
// check index of shape 1-D
PADDLE_ENFORCE_EQ(
index->dims().size() == 1 ||
(index->dims().size() == 2 && index->dims()[1] == 1),
true, platform::errors::InvalidArgument(
"index's shape is error, "
"expect index'dims shape is 1 or 2 and index.dims[1] is 1"
"but got index'dims shape is %d",
index->dims().size()));
int index_size = static_cast<int>(index->dims()[0]);
auto x_dims = x->dims();
auto update_dims = updates->dims();
for (int i = 1; i < x_dims.size(); i++)
PADDLE_ENFORCE_EQ(
x_dims[i], update_dims[i],
platform::errors::InvalidArgument(
"The dimensions of the source tensor and target tensor should"
" match, but received source tensor's %d-th dimension is %d,"
"target tensor's %d-th dimension is %d.",
i, x_dims[i], i, update_dims[i]));
int dim0 = static_cast<int>(x->dims()[0]);
int dim1 = static_cast<int>(
framework::product(framework::slice_ddim(x_dims, 1, x_dims.size())));
T *out_data = out->data<T>();
const T *updates_data = updates->data<T>();
auto &dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>();
int r = XPU_SUCCESS;
Tensor indices_cpu(index->type());
framework::TensorCopy(*index, platform::CPUPlace(), &indices_cpu);
if (index_type == framework::proto::VarType::INT32) {
auto index_data = const_cast<int *>(index->data<int>());
xpu::VectorParam<int> indices{indices_cpu.data<int>(), index_size,
index_data};
r = xpu::scatter(dev_ctx.x_context(), updates_data, out_data, indices,
dim0, dim1, overwrite);
} else {
auto index_data = const_cast<int64_t *>(index->data<int64_t>());
xpu::VectorParam<int64_t> indices{indices_cpu.data<int64_t>(), index_size,
index_data};
r = xpu::scatter(dev_ctx.x_context(), updates_data, out_data, indices,
dim0, dim1, overwrite);
}
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"XPU scatter kernel return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(scatter, ops::ScatterOpXPUKernel<float>,
ops::ScatterOpXPUKernel<int64_t>);
#endif
...@@ -32,6 +32,9 @@ XPUOpMap& get_kl2_ops() { ...@@ -32,6 +32,9 @@ XPUOpMap& get_kl2_ops() {
{"adamw", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"adamw", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"adam", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"adam", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"argsort", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"assign_value", {"assign_value",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"batch_norm_grad", {"batch_norm_grad",
...@@ -263,6 +266,8 @@ XPUOpMap& get_kl2_ops() { ...@@ -263,6 +266,8 @@ XPUOpMap& get_kl2_ops() {
{"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), {"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})}, pOpKernelType(vartype::INT64, XPUPlace())})},
{"scatter", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"shape", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), {"shape", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})}, pOpKernelType(vartype::INT64, XPUPlace())})},
{"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), {"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
from op_test_xpu import XPUOpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import ParamAttr
from paddle.fluid.framework import Program, grad_var_name
from paddle.fluid.executor import Executor
from paddle.fluid.backward import append_backward
paddle.enable_static()
class TestArgsortOp(XPUOpTest):
def setUp(self):
self.set_xpu()
self.op_type = "argsort"
self.place = paddle.XPUPlace(0)
self.init_dtype()
self.init_inputshape()
self.init_axis()
self.init_direction()
self.x = np.random.random(self.input_shape).astype(self.dtype)
self.inputs = {"X": self.x}
self.attrs = {"axis": self.axis, "descending": self.descending}
self.get_output()
self.outputs = {"Out": self.sorted_x, "Indices": self.indices}
def get_output(self):
if self.descending:
self.indices = np.flip(
np.argsort(
self.x, kind='heapsort', axis=self.axis), self.axis)
self.sorted_x = np.flip(
np.sort(
self.x, kind='heapsort', axis=self.axis), self.axis)
else:
self.indices = np.argsort(self.x, kind='heapsort', axis=self.axis)
self.sorted_x = np.sort(self.x, kind='heapsort', axis=self.axis)
def set_xpu(self):
self.__class__.use_xpu = True
self.__class__.no_need_check_grad = True
def init_inputshape(self):
self.input_shape = (2, 2, 2, 3, 3)
def init_dtype(self):
self.dtype = 'float32'
def init_axis(self):
self.axis = -1
def test_check_output(self):
self.check_output_with_place(self.place)
def init_direction(self):
self.descending = False
class TestArgsortOpAxis0XPU(TestArgsortOp):
def init_axis(self):
self.axis = 0
class TestArgsortOpAxis1XPU(TestArgsortOp):
def init_axis(self):
self.axis = 1
class TestArgsortOpAxis2XPU(TestArgsortOp):
def init_axis(self):
self.axis = 2
class TestArgsortOpAxisNeg1XPU(TestArgsortOp):
def init_axis(self):
self.axis = -1
class TestArgsortOpAxisNeg2XPU(TestArgsortOp):
def init_axis(self):
self.axis = -2
class TestArgsortOpDescendingAxisXPU(TestArgsortOp):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxis0XPU(TestArgsortOpAxis0XPU):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxis1XPU(TestArgsortOpAxis1XPU):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxis2XPU(TestArgsortOpAxis2XPU):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxisNeg1XPU(TestArgsortOpAxisNeg1XPU):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxisNeg2XPU(TestArgsortOpAxisNeg2XPU):
def init_direction(self):
self.descending = True
class TestArgsortOpAxis0XPUINT64(TestArgsortOp):
def setUp(self):
self.set_xpu()
self.op_type = "argsort"
self.place = paddle.XPUPlace(0)
self.init_dtype()
self.init_inputshape()
self.init_axis()
self.init_direction()
self.x = np.random.randint(
low=-1000, high=1000, size=self.input_shape).astype(self.dtype)
self.inputs = {"X": self.x}
self.attrs = {"axis": self.axis, "descending": self.descending}
self.get_output()
self.outputs = {"Out": self.sorted_x, "Indices": self.indices}
def init_axis(self):
self.axis = 0
def init_dtype(self):
self.dtype = 'int64'
class TestArgsortOpAxis1XPUINT64(TestArgsortOpAxis0XPUINT64):
def init_axis(self):
self.axis = 1
class TestArgsortOpAxis2XPUINT64(TestArgsortOpAxis0XPUINT64):
def init_axis(self):
self.axis = 2
class TestArgsortOpAxisNeg1XPUINT64(TestArgsortOpAxis0XPUINT64):
def init_axis(self):
self.axis = -1
class TestArgsortOpAxisNeg2XPUINT64(TestArgsortOpAxis0XPUINT64):
def init_axis(self):
self.axis = -2
class TestArgsortOpDescendingAxisXPUINT64(TestArgsortOpAxis0XPUINT64):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxis0XPUINT64(TestArgsortOpAxis0XPUINT64):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxis1XPUINT64(TestArgsortOpAxis1XPUINT64):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxis2XPUINT64(TestArgsortOpAxis2XPUINT64):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxisNeg1XPUINT64(TestArgsortOpAxisNeg1XPUINT64):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxisNeg2XPUINT64(TestArgsortOpAxisNeg2XPUINT64):
def init_direction(self):
self.descending = True
class TestArgsortOpAxis0XPUINT(TestArgsortOp):
def setUp(self):
self.set_xpu()
self.op_type = "argsort"
self.place = paddle.XPUPlace(0)
self.init_dtype()
self.init_inputshape()
self.init_axis()
self.init_direction()
self.x = np.random.randint(
low=-1000, high=1000, size=self.input_shape).astype(self.dtype)
self.inputs = {"X": self.x}
self.attrs = {"axis": self.axis, "descending": self.descending}
self.get_output()
self.outputs = {"Out": self.sorted_x, "Indices": self.indices}
def init_axis(self):
self.axis = 0
def init_dtype(self):
self.dtype = 'int'
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
from op_test_xpu import XPUOpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
paddle.enable_static()
class TestScatterOp(XPUOpTest):
def setUp(self):
self.set_xpu()
self.op_type = "scatter"
self.place = paddle.XPUPlace(0)
ref_np = np.ones((3, 50)).astype("float32")
index_np = np.array([1, 2]).astype("int32")
updates_np = np.random.random((2, 50)).astype("float32")
output_np = np.copy(ref_np)
output_np[index_np] = updates_np
self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
self.outputs = {'Out': output_np}
def set_xpu(self):
self.__class__.use_xpu = True
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
class TestScatterOp0(TestScatterOp):
def setUp(self):
self.set_xpu()
self.op_type = "scatter"
self.place = paddle.XPUPlace(0)
ref_np = np.ones((3, 3)).astype("float32")
index_np = np.array([1, 2]).astype("int32")
updates_np = np.random.random((2, 3)).astype("float32")
output_np = np.copy(ref_np)
output_np[index_np] = updates_np
self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
self.attrs = {'overwrite': True}
self.outputs = {'Out': output_np}
class TestScatterOp1(TestScatterOp):
def setUp(self):
self.set_xpu()
self.op_type = "scatter"
self.place = paddle.XPUPlace(0)
ref_np = np.ones((3, 3)).astype("float32")
zeros_np = np.zeros([2, 3]).astype('float32')
index_np = np.array([1, 1]).astype("int32")
updates_np = np.random.random((2, 3)).astype("float32")
output_np = np.copy(ref_np)
output_np[index_np] = zeros_np
for i in range(0, len(index_np)):
output_np[index_np[i]] += updates_np[i]
self.attrs = {'overwrite': False}
self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
self.outputs = {'Out': output_np}
class TestScatterOp2(TestScatterOp):
def setUp(self):
self.set_xpu()
self.op_type = "scatter"
self.place = paddle.XPUPlace(0)
ref_np = np.ones((3, 3)).astype("float32")
index_np = np.array([1, 2]).astype("int32")
updates_np = np.random.random((2, 3)).astype("float32")
output_np = np.copy(ref_np)
output_np[index_np] = updates_np
self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
self.outputs = {'Out': output_np}
class TestScatterOp3(TestScatterOp):
def setUp(self):
self.set_xpu()
self.op_type = "scatter"
self.place = paddle.XPUPlace(0)
ref_np = np.ones((3, 3)).astype("float32")
zeros_np = np.zeros([2, 3]).astype('float32')
index_np = np.array([1, 1]).astype("int32")
updates_np = np.random.random((2, 3)).astype("float32")
output_np = np.copy(ref_np)
output_np[index_np] = zeros_np
for i in range(0, len(index_np)):
output_np[index_np[i]] += updates_np[i]
self.attrs = {'overwrite': False}
self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
self.outputs = {'Out': output_np}
class TestScatterOp4(TestScatterOp):
def setUp(self):
self.set_xpu()
self.op_type = "scatter"
self.place = paddle.XPUPlace(0)
ref_np = np.ones((3, 3)).astype("float32")
index_np = np.array([1, 2]).astype("int64")
updates_np = np.random.random((2, 3)).astype("float32")
output_np = np.copy(ref_np)
output_np[index_np] = updates_np
self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
self.outputs = {'Out': output_np}
class TestScatterOp5(TestScatterOp):
def setUp(self):
self.set_xpu()
self.op_type = "scatter"
self.place = paddle.XPUPlace(0)
ref_np = np.ones((3, 3)).astype("float32")
index_np = np.array([1, 2]).astype("int64")
updates_np = np.random.random((2, 3)).astype("float32")
output_np = np.copy(ref_np)
output_np[index_np] = updates_np
self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
self.outputs = {'Out': output_np}
class TestScatterOp6(TestScatterOp):
def setUp(self):
self.set_xpu()
self.op_type = "scatter"
self.place = paddle.XPUPlace(0)
ref_np = np.ones((3, 3)).astype("int64")
index_np = np.array([1, 2]).astype("int64")
updates_np = np.random.random((2, 3)).astype("int64")
output_np = np.copy(ref_np)
output_np[index_np] = updates_np
self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
self.outputs = {'Out': output_np}
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册