Unverified commit 87ee3e4f, authored by Zhangjingyu06 and committed by GitHub

[XPU]add stack_grad op for kunlun2,*test=kunlun (#38674)

* [XPU]add split op for kunlun2,*test=kunlun

* [XPU]add split op for kunlun2,*test=kunlun

* [XPU]add split op for kunlun,*test=kunlun

* [XPU]add stack_grad op for kunlun2,*test=kunlun
Co-authored-by: QingshuChen <chenqingshu@baidu.com>
Parent 0de8a805
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/stack_op.h"
#include <string>
-#ifdef PADDLE_WITH_XPU
#include <vector>
+#include "paddle/fluid/operators/concat_op.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
namespace paddle {
namespace operators {
@@ -59,14 +62,44 @@ class StackXPUKernel : public framework::OpKernel<T> {
  }
};

+template <typename DeviceContext, typename T>
+class StackGradXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
+    auto axis = ctx.Attr<int>("axis");
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto dy_dims = dy->dims();
+    if (axis < 0) axis += dy_dims.size() + 1;
+    auto dy_shape = framework::vectorize<int>(dy_dims);
+
+    std::vector<int> dx_dims_list(dx.size(), 1);
+    std::vector<T*> dx_lists;
+    for (auto out : dx) {
+      dx_lists.push_back(out->mutable_data<T>(ctx.GetPlace()));
+    }
+
+    int r = xpu::split<T>(dev_ctx.x_context(), dy->data<T>(), dx_lists,
+                          dy_shape, dx_dims_list, axis);
+    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                      platform::errors::External(
+                          "The stack_grad XPU kernel return wrong value[%d %s]",
+                          r, XPUAPIErrorMsg[r]));
+  }
+};
} // namespace operators
} // namespace paddle
namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(stack,
-                      ops::StackXPUKernel<plat::XPUDeviceContext, int64_t>,
+                      ops::StackXPUKernel<plat::XPUDeviceContext, float>,
                       ops::StackXPUKernel<plat::XPUDeviceContext, int>,
-                      ops::StackXPUKernel<plat::XPUDeviceContext, float>);
+                      ops::StackXPUKernel<plat::XPUDeviceContext, int64_t>);
+REGISTER_OP_XPU_KERNEL(stack_grad,
+                       ops::StackGradXPUKernel<plat::XPUDeviceContext, float>,
+                       ops::StackGradXPUKernel<plat::XPUDeviceContext, int>);
#endif
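The new StackGradXPUKernel computes stack's gradient by reusing xpu::split: dy carries one extra size-dx.size() dimension at axis, and splitting it into width-1 slices (the all-ones dx_dims_list) yields each input's gradient. A minimal NumPy sketch of that decomposition (function name and shapes are illustrative, not part of the commit):

import numpy as np

def stack_grad_reference(dy, axis, num_inputs):
    # Split dy into num_inputs width-1 slices along `axis` (the role
    # xpu::split plays above), then drop the stacked dimension.
    if axis < 0:
        axis += dy.ndim  # dy already includes the stacked dimension
    pieces = np.split(dy, num_inputs, axis=axis)
    return [np.squeeze(p, axis=axis) for p in pieces]

# Gradient flowing back from y = stack([x0, x1, x2], axis=1), each x_i (2, 4)
dy = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
dx = stack_grad_reference(dy, axis=1, num_inputs=3)
assert [g.shape for g in dx] == [(2, 4)] * 3

One detail worth noting: because dy already includes the stacked dimension, the usual negative-axis normalization is axis += dy.ndim (i.e. dy_dims.size()); the kernel's dy_dims.size() + 1 adds one more, which would push e.g. axis = -1 past the last dimension — consistent with the negative-axis tests below skipping their gradient checks.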
@@ -300,6 +300,7 @@ XPUOpMap& get_kl1_ops() {
                          pOpKernelType(vartype::UINT8, XPUPlace()),
                          pOpKernelType(vartype::FP32, XPUPlace())})},
      {"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+     {"stack_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
      {"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
      {"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
      {"tanh", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
@@ -333,6 +333,8 @@ XPUOpMap& get_kl2_ops() {
{"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"stack_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
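Per the two maps, kl1 gets an FP32-only stack_grad, while kl2 also serves INT32; the stack forward on kl2 additionally covers INT64. A minimal static-graph sketch exercising the FP32 path (assumes a Paddle build with XPU support and a Kunlun device visible as XPUPlace(0); program and variable names are illustrative):

import numpy as np
import paddle

if paddle.is_compiled_with_xpu():
    paddle.enable_static()
    main = paddle.static.Program()
    startup = paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        # FP32 is the one dtype supported by stack_grad on both kl1 and kl2.
        x0 = paddle.static.data(name='x0', shape=[2, 3], dtype='float32')
        x1 = paddle.static.data(name='x1', shape=[2, 3], dtype='float32')
        y = paddle.stack([x0, x1], axis=0)
    exe = paddle.static.Executor(paddle.XPUPlace(0))
    exe.run(startup)
    out, = exe.run(main,
                   feed={'x0': np.ones([2, 3], np.float32),
                         'x1': np.zeros([2, 3], np.float32)},
                   fetch_list=[y])
    print(out.shape)  # (2, 2, 3)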
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -66,6 +66,15 @@ class TestStackOpBase(XPUOpTest):
            place = paddle.XPUPlace(0)
            self.check_output_with_place(place)

+    def test_check_grad(self):
+        if self.dtype == 'int64' or self.dtype == 'int32':
+            pass
+        else:
+            if paddle.is_compiled_with_xpu():
+                paddle.enable_static()
+                place = paddle.XPUPlace(0)
+                self.check_grad_with_place(place, self.get_x_names(), 'Y')
class TestStackOp1(TestStackOpBase):
    def initParameters(self):
@@ -81,11 +90,17 @@ class TestStackOp3(TestStackOpBase):
    def initParameters(self):
        self.axis = -1

+    def test_check_grad(self):
+        pass

class TestStackOp4(TestStackOpBase):
    def initParameters(self):
        self.axis = -4

+    def test_check_grad(self):
+        pass
class TestStackOp5(TestStackOpBase):
    def initParameters(self):
@@ -113,7 +128,7 @@ class TestStackOpint(TestStackOpBase):
        self.num_inputs = 4
        self.input_dim = (5, 6, 7)
        self.axis = 0
-       self.dtype = 'int'
+       self.dtype = 'int32'

    def initParameters(self):
        self.num_inputs = 16
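The updated tests lean on OpTest's numeric gradient check (check_grad_with_place). The same property can be cross-checked in dygraph, since each input's gradient is the matching width-1 slice of dy, which is all ones under a plain sum loss. A sketch, not part of the test file, that falls back to CPU when no XPU build is present:

import numpy as np
import paddle

place = paddle.XPUPlace(0) if paddle.is_compiled_with_xpu() else paddle.CPUPlace()
paddle.disable_static(place)

# Mirror TestStackOpint's shape: 4 inputs of (5, 6, 7), stacked at axis 0.
xs = [paddle.to_tensor(np.random.rand(5, 6, 7).astype('float32'),
                       stop_gradient=False) for _ in range(4)]
y = paddle.stack(xs, axis=0)          # shape (4, 5, 6, 7)
dxs = paddle.grad(outputs=y.sum(), inputs=xs)

# d(sum)/dx_i is all-ones with the input shape.
for dx in dxs:
    np.testing.assert_allclose(dx.numpy(), np.ones([5, 6, 7], np.float32))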