未验证 提交 63aeb02d 编写于 作者: T TTerror 提交者: GitHub

fix gather op and add logsumexp op on kunlun (#32931) (#33592)

* fix gather op and add logsumexp op on kunlun

* update xpu depence

* update tests and fix elementwise_add
上级 bb5963da
......@@ -13,7 +13,7 @@ if(NOT XPU_SDK_ROOT)
elseif(WITH_SUNWAY)
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/sunway/xpu_2021_01_13.tar.gz" CACHE STRING "" FORCE)
else()
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_04_09.tar.gz" CACHE STRING "" FORCE)
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_05_19.tar.gz" CACHE STRING "" FORCE)
endif()
SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu")
......
......@@ -141,6 +141,7 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
}
}
const T* dz_data = dz->data<T>();
T* dx_data = nullptr;
T* dy_data = nullptr;
if (dx) {
......@@ -152,9 +153,9 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
auto& dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>();
int ret = xpu::broadcast_add_grad<T>(dev_ctx.x_context(), dx_data, dx_data,
dx_data, dz->data<T>(), dy_data,
dx_data, x_dims_vec, y_dims_vec);
int ret = xpu::broadcast_add_grad<T>(dev_ctx.x_context(), dz_data, dz_data,
dz_data, dz_data, dy_data, dx_data,
x_dims_vec, y_dims_vec);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External(
......
......@@ -40,16 +40,6 @@ class GatherOpXPUKernel : public framework::OpKernel<T> {
output->mutable_data<T>(ctx.GetPlace());
if (x->numel() == 0) return;
// check index type is INT32
const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32;
PADDLE_ENFORCE_EQ(
index_type_match, true,
platform::errors::InvalidArgument(
"XPU only support INT32, it holds %s, but desires to be %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32)));
const auto index_dims = index->dims();
if (index_dims.size() == 2) {
......@@ -65,14 +55,26 @@ class GatherOpXPUKernel : public framework::OpKernel<T> {
"The index should be 1D, when it is not 2D, but we get %d",
index_dims.size()));
}
int slice_size = x->numel() / x->dims()[0];
std::vector<int> xshape(x->dims().size());
for (int i = 0; i < x->dims().size(); ++i) {
xshape[i] = x->dims()[i];
}
auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
int r =
xpu::gather<T>(dev_ctx.x_context(), x->data<T>(), index->data<int>(),
index->dims()[0], slice_size, output->data<T>());
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::External("XPU kernel error! error code=%d", r));
int r = XPU_SUCCESS;
if (index->type() == framework::proto::VarType::INT32) {
r = xpu::gather<T, int>(dev_ctx.x_context(), x->data<T>(),
index->data<int>(), output->data<T>(), xshape,
index->dims()[0], 0);
} else {
r = xpu::gather<T, int64_t>(dev_ctx.x_context(), x->data<T>(),
index->data<int64_t>(), output->data<T>(),
xshape, index->dims()[0], 0);
}
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External(
"XPU gather kernel return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
}
};
......@@ -93,30 +95,11 @@ class GatherGradOpXPUKernel : public framework::OpKernel<T> {
PADDLE_THROW(platform::errors::InvalidArgument(
"Now, it doesn't support XPU with Axis."));
}
dx->mutable_data<T>(ctx.GetPlace());
const int zero = 0;
int r_dx = xpu::memset(dev_ctx.x_context(), dx->data<T>(), zero,
dx->numel() * sizeof(T));
PADDLE_ENFORCE_EQ(
r_dx, xpu::Error_t::SUCCESS,
platform::errors::External("XPU kernel error! error code=%d", r_dx));
if (dout->numel() == 0) {
return;
}
bool overwrite = ctx.Attr<bool>("overwrite");
// check index type is INT32
const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32;
PADDLE_ENFORCE_EQ(
index_type_match, true,
platform::errors::InvalidArgument(
"XPU only support INT32, it holds %s, but desires to be %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32)));
bool overwrite = ctx.Attr<bool>("overwrite");
const auto index_dims = index->dims();
if (index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(
......@@ -131,16 +114,27 @@ class GatherGradOpXPUKernel : public framework::OpKernel<T> {
"The index should be 1D, when it is not 2D, but we get %d",
index_dims.size()));
}
std::vector<int> xshape(dx->dims().size());
for (int i = 0; i < dx->dims().size(); ++i) {
xshape[i] = dx->dims()[i];
}
int index_size = index_dims[0];
int slice_size = dout->numel() / dout->dims()[0];
dx->mutable_data<T>(ctx.GetPlace());
int r = xpu::scatter<T>(dev_ctx.x_context(), dout->data<T>(),
index->data<int>(), index_size, slice_size,
dx->data<T>(), overwrite);
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::External("XPU kernel error! error code=%d", r));
int r = XPU_SUCCESS;
if (index->type() == framework::proto::VarType::INT32) {
r = xpu::gather_grad<T, int>(dev_ctx.x_context(), dout->data<T>(),
index->data<int>(), dx->data<T>(), xshape,
index->dims()[0], 0, overwrite);
} else {
r = xpu::gather_grad<T, int64_t>(dev_ctx.x_context(), dout->data<T>(),
index->data<int64_t>(), dx->data<T>(),
xshape, index->dims()[0], 0, overwrite);
}
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External(
"XPU gather grad kernel return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
}
};
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/xpu_header.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class XPULogsumexpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out");
auto axis = context.Attr<std::vector<int>>("axis");
auto reduce_all = context.Attr<bool>("reduce_all");
const auto& input_dim_size = input->dims().size();
// The dims has full dim, set the reduce_all is True
reduce_all |= (static_cast<const int>(axis.size()) == input_dim_size);
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
std::vector<int> axis_shape;
std::vector<int> xdims(input_dim_size);
for (int i = 0; i < input_dim_size; ++i) {
xdims[i] = input->dims()[i];
}
if (reduce_all) {
for (int i = 0; i < input_dim_size; ++i) {
axis_shape.push_back(i);
}
} else {
for (size_t i = 0; i < axis.size(); ++i) {
int rdim = axis[i] < 0 ? axis[i] + input_dim_size : axis[i];
axis_shape.push_back(rdim);
}
}
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::logsumexp<T>(dev_ctx.x_context(), input_data, output_data,
xdims, axis_shape);
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External(
"XPU logsumexp kernel error! error value[%d %]", r,
XPUAPIErrorMsg[r]));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
logsumexp,
ops::XPULogsumexpKernel<paddle::platform::XPUDeviceContext, float>);
#endif
......@@ -13,13 +13,18 @@
# limitations under the License.
from __future__ import print_function
import unittest
import sys
sys.path.append("..")
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from op_test import OpTest
from op_test_xpu import XPUOpTest
paddle.enable_static()
def gather_numpy(x, index, axis):
......@@ -29,37 +34,12 @@ def gather_numpy(x, index, axis):
return gather
class TestGatherOp(OpTest):
def setUp(self):
self.op_type = "gather"
self.config()
xnp = np.random.random(self.x_shape).astype(self.x_type)
self.inputs = {
'X': xnp,
'Index': np.array(self.index).astype(self.index_type)
}
self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
def config(self):
"""
For multi-dimension input
"""
self.x_shape = (10, 20)
self.x_type = "float64"
self.index = [1, 3, 5]
self.index_type = "int32"
class TestXPUGatherOp(OpTest):
class TestXPUGatherOp(XPUOpTest):
def setUp(self):
self.dtype = "float32"
self.op_type = "gather"
self.dtype = np.float32
self.use_xpu = True
self.use_mkldnn = False
self.attrs = {'use_xpu': True}
self.config()
......@@ -71,12 +51,12 @@ class TestXPUGatherOp(OpTest):
self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
def test_check_output(self):
if self.dtype == np.float32 and paddle.is_compiled_with_xpu():
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
if self.dtype == np.float32 and paddle.is_compiled_with_xpu():
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
......@@ -85,7 +65,7 @@ class TestXPUGatherOp(OpTest):
For multi-dimension input
"""
self.x_shape = (10, 20)
self.x_type = self.dtype
self.x_type = "float32"
self.index = [1, 3, 5]
self.index_type = "int32"
......@@ -150,5 +130,14 @@ class TestCase6(TestXPUGatherOp):
self.index_type = "int32"
class TestCase7(TestXPUGatherOp):
def config(self):
self.x_shape = (10, 20)
self.attrs = {'use_xpu': True, 'overwrite': True}
self.x_type = "float32"
self.index = [1, 3]
self.index_type = "int64"
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
import sys
sys.path.append("..")
import numpy as np
from op_test import OpTest
from op_test_xpu import XPUOpTest
paddle.enable_static()
def ref_logsumexp(x, axis=None, keepdim=False, reduce_all=False):
if isinstance(axis, int):
axis = (axis, )
elif isinstance(axis, list):
axis = tuple(axis)
if reduce_all:
axis = None
out = np.log(np.exp(x).sum(axis=axis, keepdims=keepdim))
return out
class XPUTestLogsumexp(XPUOpTest):
def setUp(self):
self.op_type = 'logsumexp'
self.shape = [2, 3, 4, 5]
self.dtype = 'float32'
self.axis = [-1]
self.keepdim = False
self.reduce_all = False
self.set_attrs()
np.random.seed(10)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
out = ref_logsumexp(x, self.axis, self.keepdim, self.reduce_all)
self.inputs = {'X': x}
self.outputs = {'Out': out}
self.attrs = {
'axis': self.axis,
'keepdim': self.keepdim,
'reduce_all': self.reduce_all
}
def set_attrs(self):
pass
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
pass
class TestLogsumexp_shape(XPUTestLogsumexp):
def set_attrs(self):
self.shape = [4, 5, 6]
class TestLogsumexp_axis(XPUTestLogsumexp):
def set_attrs(self):
self.axis = [0, -1]
class TestLogsumexp_axis_all(XPUTestLogsumexp):
def set_attrs(self):
self.axis = [0, 1, 2, 3]
class TestLogsumexp_keepdim(XPUTestLogsumexp):
def set_attrs(self):
self.keepdim = True
class TestLogsumexp_reduce_all(XPUTestLogsumexp):
def set_attrs(self):
self.reduce_all = True
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册