未验证 提交 067107ad 编写于 作者: z8hanghuan's avatar z8hanghuan 提交者: GitHub

support log_grad op, *test=kunlun (#44662)

上级 acde295c
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -358,6 +358,52 @@ struct XPULeakyReluGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct XPULogGradFunctor : public BaseActivationFunctor<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
void operator()(const framework::ExecutionContext &ctx) const {
const auto *x = ctx.Input<Tensor>("X");
auto *dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
const T *x_data = nullptr;
const T *y_grad = nullptr;
if (x != nullptr) x_data = x->data<T>();
if (dOut != nullptr) y_grad = dOut->data<T>();
T *x_grad = dX->mutable_data<T>(ctx.GetPlace());
auto dev_ctx =
ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
const auto x_dims = x->dims();
auto xshape = phi::vectorize<int>(x_dims);
int len = x->dims()[x_dims.size() - 1];
std::vector<int> yshape(1, len);
xpu::ctx_guard RAII_GUARD(dev_ctx);
T *y_data = RAII_GUARD.alloc_l3_or_gm<T>(len);
PADDLE_ENFORCE_XDNN_NOT_NULL(y_data);
T *tmp_grad = RAII_GUARD.alloc_l3_or_gm<T>(x->numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(tmp_grad);
int r = xpu::constant<T>(dev_ctx, y_data, len, static_cast<T>(1.0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
// dx.device(d) = dout * (static_cast<T>(1) / x);
r = xpu::broadcast_div(dev_ctx,
reinterpret_cast<const float *>(y_data),
reinterpret_cast<const float *>(x_data),
reinterpret_cast<float *>(tmp_grad),
yshape,
xshape);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_div");
r = xpu::broadcast_mul(dev_ctx,
reinterpret_cast<const float *>(y_grad),
reinterpret_cast<const float *>(tmp_grad),
reinterpret_cast<float *>(x_grad),
xshape,
xshape);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_mul");
}
};
template <typename T>
struct XPUPowFunctor : public BaseActivationFunctor<T> {
void operator()(const framework::ExecutionContext &ctx) const {
......@@ -584,5 +630,6 @@ REGISTER_OP_XPU_KERNEL(exp,
ops::XPUActivationKernel<ops::XPUExpFunctor<float>>);
REGISTER_OP_XPU_KERNEL(log,
ops::XPUActivationKernel<ops::XPULogFunctor<float>>);
REGISTER_OP_XPU_KERNEL(
log_grad, ops::XPUActivationGradKernel<ops::XPULogGradFunctor<float>>);
#endif // PADDLE_WITH_XPU
......@@ -276,6 +276,7 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"log", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"log_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"log_softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"log_softmax_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
......
......@@ -34,12 +34,15 @@ class TestActivationOPBase(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.init_dtype()
self.set_shape()
self.set_case()
def set_shape(self):
self.shape = [11, 17]
def set_case(self):
self.op_type = 'exp'
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
out = np.exp(x)
self.attrs = {'use_xpu': True}
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
......@@ -313,14 +316,33 @@ class XPUTestLogOP(XPUOpTestWrapper):
def set_case(self):
self.op_type = "log"
self.dtype = self.in_type
x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
out = np.log(x)
self.attrs = {'use_xpu': True}
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
class TestLogCase1(XPUTestLog):
def set_shape(self):
self.shape = [1, 11, 17]
class TestLogCase2(XPUTestLog):
def set_shape(self):
self.shape = [2, 2, 2]
class TestLogCase3(XPUTestLog):
def set_shape(self):
self.shape = [2]
class TestLogCase4(XPUTestLog):
def set_shape(self):
self.shape = [1, 2, 3, 4]
support_types = get_xpu_op_support_types('log')
for stype in support_types:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册