Unverified commit a96e8bc9, authored by TTerror, committed by GitHub

fix gather op and add logsumexp op on kunlun (#32931)

* fix gather op and add logsumexp op on kunlun

* update xpu dependency

* update tests and fix elementwise_add
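
For orientation, a hedged sketch (my illustration, not part of the commit) of what the patch enables at the Python API level: gather on Kunlun now accepts int64 as well as int32 index tensors, and logsumexp gains an XPU kernel. It assumes a Paddle build with XPU support; on other builds the same code simply runs on CPU.

import paddle

# Select the Kunlun device when available (assumes an XPU build of Paddle).
if paddle.is_compiled_with_xpu():
    paddle.set_device("xpu")

x = paddle.rand([10, 20], dtype="float32")

# gather: after this patch the XPU kernel dispatches on the index dtype,
# so int64 indices work as well as int32.
idx = paddle.to_tensor([1, 3, 5], dtype="int64")
out = paddle.gather(x, idx)         # shape [3, 20]

# logsumexp: newly added XPU kernel; reduce over the last axis here.
lse = paddle.logsumexp(x, axis=-1)  # shape [10]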
Parent be8e94aa
@@ -13,7 +13,7 @@ if(NOT XPU_SDK_ROOT)
   elseif(WITH_SUNWAY)
     SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/sunway/xpu_2021_01_13.tar.gz" CACHE STRING "" FORCE)
   else()
-    SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_04_09_2.tar.gz" CACHE STRING "" FORCE)
+    SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_05_19.tar.gz" CACHE STRING "" FORCE)
   endif()
   SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu")
......
@@ -141,6 +141,7 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
       }
     }
+    const T* dz_data = dz->data<T>();
     T* dx_data = nullptr;
     T* dy_data = nullptr;
     if (dx) {
@@ -152,9 +153,9 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
     auto& dev_ctx =
         ctx.template device_context<paddle::platform::XPUDeviceContext>();
-    int ret = xpu::broadcast_add_grad<T>(dev_ctx.x_context(), dx_data, dx_data,
-                                         dx_data, dz->data<T>(), dy_data,
-                                         dx_data, x_dims_vec, y_dims_vec);
+    int ret = xpu::broadcast_add_grad<T>(dev_ctx.x_context(), dz_data, dz_data,
+                                         dz_data, dz_data, dy_data, dx_data,
+                                         x_dims_vec, y_dims_vec);
     PADDLE_ENFORCE_EQ(
         ret, xpu::SUCCESS,
         platform::errors::External(
......
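
Context for the elementwise_add fix above: dx_data starts out as nullptr and is only set when the dx output is requested, so the old call could hand nullptr to xpu::broadcast_add_grad in slots the SDK reads; the new call passes dz_data for the forward-tensor slots, which is safe because for z = x + y the gradients do not depend on the input values. The exact XDNN signature is SDK-internal; the numpy sketch below is only my reference for the math the kernel computes: each gradient is dz summed over the axes along which that operand was broadcast.

import numpy as np

# Reference semantics for the gradient of a broadcast add, z = x + y.
def broadcast_add_grad_ref(dz, x_shape, y_shape):
    def reduce_to(shape, grad):
        # Sum out leading axes that broadcasting prepended.
        while grad.ndim > len(shape):
            grad = grad.sum(axis=0)
        # Sum over axes where the operand had size 1.
        for axis, size in enumerate(shape):
            if size == 1:
                grad = grad.sum(axis=axis, keepdims=True)
        return grad
    return reduce_to(x_shape, dz), reduce_to(y_shape, dz)

dz = np.ones((2, 3, 4), dtype=np.float32)
dx, dy = broadcast_add_grad_ref(dz, (2, 3, 4), (3, 1))
assert dx.shape == (2, 3, 4) and dy.shape == (3, 1)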
@@ -40,16 +40,6 @@ class GatherOpXPUKernel : public framework::OpKernel<T> {
     output->mutable_data<T>(ctx.GetPlace());
     if (x->numel() == 0) return;
-    // check index type is INT32
-    const auto &index_type = index->type();
-    bool index_type_match = index_type == framework::proto::VarType::INT32;
-    PADDLE_ENFORCE_EQ(
-        index_type_match, true,
-        platform::errors::InvalidArgument(
-            "XPU only support INT32, it holds %s, but desires to be %s",
-            paddle::framework::DataTypeToString(index_type),
-            paddle::framework::DataTypeToString(
-                framework::proto::VarType::INT32)));
     const auto index_dims = index->dims();
     if (index_dims.size() == 2) {
@@ -65,14 +55,26 @@ class GatherOpXPUKernel : public framework::OpKernel<T> {
             "The index should be 1D, when it is not 2D, but we get %d",
             index_dims.size()));
     }
-    int slice_size = x->numel() / x->dims()[0];
+    std::vector<int> xshape(x->dims().size());
+    for (int i = 0; i < x->dims().size(); ++i) {
+      xshape[i] = x->dims()[i];
+    }
     auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
-    int r =
-        xpu::gather<T>(dev_ctx.x_context(), x->data<T>(), index->data<int>(),
-                       index->dims()[0], slice_size, output->data<T>());
-    PADDLE_ENFORCE_EQ(
-        r, xpu::Error_t::SUCCESS,
-        platform::errors::External("XPU kernel error! error code=%d", r));
+    int r = XPU_SUCCESS;
+    if (index->type() == framework::proto::VarType::INT32) {
+      r = xpu::gather<T, int>(dev_ctx.x_context(), x->data<T>(),
+                              index->data<int>(), output->data<T>(), xshape,
+                              index->dims()[0], 0);
+    } else {
+      r = xpu::gather<T, int64_t>(dev_ctx.x_context(), x->data<T>(),
+                                  index->data<int64_t>(), output->data<T>(),
+                                  xshape, index->dims()[0], 0);
+    }
+    PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
+                      platform::errors::External(
+                          "XPU gather kernel return wrong value[%d %s]", r,
+                          XPUAPIErrorMsg[r]));
   }
 };
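
The forward kernel above now dispatches on the index dtype (int32 or int64) and passes xpu::gather the full shape of x together with gather axis 0, instead of a precomputed slice_size. As a reference for the semantics, a minimal numpy sketch (mine, not the XDNN implementation):

import numpy as np

# Gather along axis 0: out[i, ...] = x[index[i], ...]
def gather_ref(x, index):
    return x[np.asarray(index)]

x = np.random.random((10, 20)).astype(np.float32)
for index_dtype in (np.int32, np.int64):  # both dtypes now reach the kernel
    idx = np.array([1, 3, 5], dtype=index_dtype)
    assert gather_ref(x, idx).shape == (3, 20)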
@@ -93,30 +95,11 @@ class GatherGradOpXPUKernel : public framework::OpKernel<T> {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Now, it doesn't support XPU with Axis."));
     }
-    dx->mutable_data<T>(ctx.GetPlace());
-    const int zero = 0;
-    int r_dx = xpu::memset(dev_ctx.x_context(), dx->data<T>(), zero,
-                           dx->numel() * sizeof(T));
-    PADDLE_ENFORCE_EQ(
-        r_dx, xpu::Error_t::SUCCESS,
-        platform::errors::External("XPU kernel error! error code=%d", r_dx));
     if (dout->numel() == 0) {
       return;
     }
-    bool overwrite = ctx.Attr<bool>("overwrite");
-    // check index type is INT32
-    const auto &index_type = index->type();
-    bool index_type_match = index_type == framework::proto::VarType::INT32;
-    PADDLE_ENFORCE_EQ(
-        index_type_match, true,
-        platform::errors::InvalidArgument(
-            "XPU only support INT32, it holds %s, but desires to be %s",
-            paddle::framework::DataTypeToString(index_type),
-            paddle::framework::DataTypeToString(
-                framework::proto::VarType::INT32)));
+    bool overwrite = ctx.Attr<bool>("overwrite");
     const auto index_dims = index->dims();
     if (index_dims.size() == 2) {
       PADDLE_ENFORCE_EQ(
@@ -131,16 +114,27 @@ class GatherGradOpXPUKernel : public framework::OpKernel<T> {
             "The index should be 1D, when it is not 2D, but we get %d",
             index_dims.size()));
     }
-    int index_size = index_dims[0];
-    int slice_size = dout->numel() / dout->dims()[0];
-    int r = xpu::scatter<T>(dev_ctx.x_context(), dout->data<T>(),
-                            index->data<int>(), index_size, slice_size,
-                            dx->data<T>(), overwrite);
-    PADDLE_ENFORCE_EQ(
-        r, xpu::Error_t::SUCCESS,
-        platform::errors::External("XPU kernel error! error code=%d", r));
+    std::vector<int> xshape(dx->dims().size());
+    for (int i = 0; i < dx->dims().size(); ++i) {
+      xshape[i] = dx->dims()[i];
+    }
+
+    dx->mutable_data<T>(ctx.GetPlace());
+
+    int r = XPU_SUCCESS;
+    if (index->type() == framework::proto::VarType::INT32) {
+      r = xpu::gather_grad<T, int>(dev_ctx.x_context(), dout->data<T>(),
+                                   index->data<int>(), dx->data<T>(), xshape,
+                                   index->dims()[0], 0, overwrite);
+    } else {
+      r = xpu::gather_grad<T, int64_t>(dev_ctx.x_context(), dout->data<T>(),
+                                       index->data<int64_t>(), dx->data<T>(),
+                                       xshape, index->dims()[0], 0, overwrite);
+    }
+    PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
+                      platform::errors::External(
+                          "XPU gather grad kernel return wrong value[%d %s]", r,
+                          XPUAPIErrorMsg[r]));
   }
 };
......
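
The backward kernel likewise dispatches on the index dtype and delegates to xpu::gather_grad, replacing the old explicit xpu::memset + xpu::scatter sequence (zero-initialization of dx is presumably handled inside the SDK call now). A numpy sketch of what the gradient computes, under my reading of the overwrite attribute (duplicates accumulate when false, last write wins when true):

import numpy as np

# Rows of dout are scattered back into dx at the gathered positions.
def gather_grad_ref(dout, index, x_shape, overwrite):
    dx = np.zeros(x_shape, dtype=dout.dtype)
    if overwrite:
        dx[index] = dout            # duplicate indices: last write wins
    else:
        np.add.at(dx, index, dout)  # duplicate indices accumulate
    return dx

dout = np.ones((3, 20), dtype=np.float32)
dx = gather_grad_ref(dout, np.array([1, 3, 1]), (10, 20), overwrite=False)
assert dx[1].sum() == 40.0  # index 1 appears twice, so gradients add up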
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_XPU

#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/xpu_header.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class XPULogsumexpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* input = context.Input<Tensor>("X");
    auto* output = context.Output<Tensor>("Out");
    auto axis = context.Attr<std::vector<int>>("axis");
    auto reduce_all = context.Attr<bool>("reduce_all");

    const auto& input_dim_size = input->dims().size();
    // If axis covers every dimension, treat the reduction as reduce_all.
    reduce_all |= (static_cast<int>(axis.size()) == input_dim_size);

    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    std::vector<int> axis_shape;
    std::vector<int> xdims(input_dim_size);
    for (int i = 0; i < input_dim_size; ++i) {
      xdims[i] = input->dims()[i];
    }
    if (reduce_all) {
      for (int i = 0; i < input_dim_size; ++i) {
        axis_shape.push_back(i);
      }
    } else {
      for (size_t i = 0; i < axis.size(); ++i) {
        // Normalize negative axes into the [0, input_dim_size) range.
        int rdim = axis[i] < 0 ? axis[i] + input_dim_size : axis[i];
        axis_shape.push_back(rdim);
      }
    }

    auto& dev_ctx = context.template device_context<DeviceContext>();
    int r = xpu::logsumexp<T>(dev_ctx.x_context(), input_data, output_data,
                              xdims, axis_shape);
    PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
                      platform::errors::External(
                          "XPU logsumexp kernel error! error value[%d %s]", r,
                          XPUAPIErrorMsg[r]));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
    logsumexp,
    ops::XPULogsumexpKernel<paddle::platform::XPUDeviceContext, float>);
#endif
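
The kernel hands xpu::logsumexp the input shape and the normalized reduction axes; mathematically the op computes log(sum(exp(x))) over those axes. Whether XDNN applies the usual max-shift internally is an SDK detail, but the numerically stable form is worth recording; a numpy sketch (mine):

import numpy as np

# Stable logsumexp: max(x) + log(sum(exp(x - max(x)))) over the given axis.
def logsumexp_ref(x, axis):
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))

x = np.array([1000.0, 1000.0])   # naive np.log(np.exp(x).sum()) overflows
print(logsumexp_ref(x, axis=0))  # ~1000.6931 (= 1000 + log 2)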
@@ -13,13 +13,18 @@
 # limitations under the License.

 from __future__ import print_function

+import unittest
 import sys
 sys.path.append("..")
-import unittest
 import numpy as np
-from op_test import OpTest
 import paddle
 import paddle.fluid as fluid
+from op_test import OpTest
+from op_test_xpu import XPUOpTest
+
+paddle.enable_static()


 def gather_numpy(x, index, axis):
@@ -29,37 +34,12 @@ def gather_numpy(x, index, axis):
     return gather

-class TestGatherOp(OpTest):
-    def setUp(self):
-        self.op_type = "gather"
-        self.config()
-        xnp = np.random.random(self.x_shape).astype(self.x_type)
-        self.inputs = {
-            'X': xnp,
-            'Index': np.array(self.index).astype(self.index_type)
-        }
-        self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
-
-    def test_check_output(self):
-        self.check_output()
-
-    def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
-
-    def config(self):
-        """
-        For multi-dimension input
-        """
-        self.x_shape = (10, 20)
-        self.x_type = "float64"
-        self.index = [1, 3, 5]
-        self.index_type = "int32"
-
-
-class TestXPUGatherOp(OpTest):
+class TestXPUGatherOp(XPUOpTest):
     def setUp(self):
+        self.dtype = "float32"
         self.op_type = "gather"
-        self.dtype = np.float32
+        self.use_xpu = True
+        self.use_mkldnn = False
         self.attrs = {'use_xpu': True}
         self.config()
@@ -71,12 +51,12 @@ class TestXPUGatherOp(OpTest):
         self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}

     def test_check_output(self):
-        if self.dtype == np.float32 and paddle.is_compiled_with_xpu():
+        if paddle.is_compiled_with_xpu():
             place = paddle.XPUPlace(0)
             self.check_output_with_place(place)

     def test_check_grad(self):
-        if self.dtype == np.float32 and paddle.is_compiled_with_xpu():
+        if paddle.is_compiled_with_xpu():
             place = paddle.XPUPlace(0)
             self.check_grad_with_place(place, ['X'], 'Out')
@@ -85,7 +65,7 @@ class TestXPUGatherOp(OpTest):
         For multi-dimension input
         """
         self.x_shape = (10, 20)
-        self.x_type = self.dtype
+        self.x_type = "float32"
         self.index = [1, 3, 5]
         self.index_type = "int32"
@@ -150,5 +130,14 @@ class TestCase6(TestXPUGatherOp):
         self.index_type = "int32"


+class TestCase7(TestXPUGatherOp):
+    def config(self):
+        self.x_shape = (10, 20)
+        self.attrs = {'use_xpu': True, 'overwrite': True}
+        self.x_type = "float32"
+        self.index = [1, 3]
+        self.index_type = "int64"
+
+
 if __name__ == "__main__":
     unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
import sys
sys.path.append("..")
import numpy as np
from op_test import OpTest
from op_test_xpu import XPUOpTest
paddle.enable_static()


def ref_logsumexp(x, axis=None, keepdim=False, reduce_all=False):
    if isinstance(axis, int):
        axis = (axis, )
    elif isinstance(axis, list):
        axis = tuple(axis)
    if reduce_all:
        axis = None
    out = np.log(np.exp(x).sum(axis=axis, keepdims=keepdim))
    return out


class XPUTestLogsumexp(XPUOpTest):
    def setUp(self):
        self.op_type = 'logsumexp'
        self.shape = [2, 3, 4, 5]
        self.dtype = 'float32'
        self.axis = [-1]
        self.keepdim = False
        self.reduce_all = False
        self.set_attrs()

        np.random.seed(10)
        x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
        out = ref_logsumexp(x, self.axis, self.keepdim, self.reduce_all)

        self.inputs = {'X': x}
        self.outputs = {'Out': out}
        self.attrs = {
            'axis': self.axis,
            'keepdim': self.keepdim,
            'reduce_all': self.reduce_all
        }

    def set_attrs(self):
        pass

    def test_check_output(self):
        if paddle.is_compiled_with_xpu():
            place = paddle.XPUPlace(0)
            self.check_output_with_place(place)

    def test_check_grad(self):
        pass


class TestLogsumexp_shape(XPUTestLogsumexp):
    def set_attrs(self):
        self.shape = [4, 5, 6]


class TestLogsumexp_axis(XPUTestLogsumexp):
    def set_attrs(self):
        self.axis = [0, -1]


class TestLogsumexp_axis_all(XPUTestLogsumexp):
    def set_attrs(self):
        self.axis = [0, 1, 2, 3]


class TestLogsumexp_keepdim(XPUTestLogsumexp):
    def set_attrs(self):
        self.keepdim = True


class TestLogsumexp_reduce_all(XPUTestLogsumexp):
    def set_attrs(self):
        self.reduce_all = True


if __name__ == '__main__':
    unittest.main()