Unverified commit 7085cb97, authored by taixiurong, committed by GitHub

xpu-paddlepaddle-40 [Task] fused_gemm_epilogue: add XPU support (#45706)

* add gemm_epilogue

* xpu-paddlepaddle-40 [Task] fused_gemm_epilogue XPU support, test=kunlun
Parent ca1cab3e
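For orientation: fused_gemm_epilogue fuses a GEMM, a broadcast bias add, and an optional relu/gelu epilogue into a single op. A minimal NumPy sketch of the intended semantics for the 2-D case (illustrative only; the function name is ours, not part of the patch):

```python
import numpy as np

def fused_gemm_epilogue_ref(x, y, bias, trans_x=False, trans_y=False,
                            activation='none'):
    """out = act(x @ y + bias), with optional operand transposes (2-D case)."""
    x2 = x.T if trans_x else x
    y2 = y.T if trans_y else y
    out = x2 @ y2 + bias  # bias has shape (n,) and broadcasts over rows
    if activation == 'relu':
        return np.maximum(out, 0)
    if activation == 'gelu':  # tanh approximation, as in the tests below
        return 0.5 * out * (1.0 + np.tanh(
            np.sqrt(2 / np.pi) * (out + 0.044715 * out ** 3)))
    return out
```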
......@@ -36,6 +36,7 @@ op_library(fusion_lstm_op)
if(WITH_XPU)
op_library(resnet_basic_block_op)
op_library(resnet_unit_op)
+ op_library(fused_gemm_epilogue_op)
endif()
if(WITH_GPU OR WITH_ROCM)
......
......@@ -392,14 +392,13 @@ class FusedGemmEpilogueOpGradMaker : public framework::SingleGradOpMaker<T> {
protected:
void Apply(GradOpPtr<T> op) const override {
const auto& act_type = this->template Attr<std::string>("activation");
- PADDLE_ENFORCE_EQ(
-     act_type,
-     "none",
-     phi::errors::InvalidArgument("The activation should be none."));
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
+ if (act_type != "none") {
+   op->SetInput("ReserveSpace", this->Input("ReserveSpace"));
+ }
op->SetInput("DOut", this->OutputGrad("Out"));
op->SetOutput("DX", this->InputGrad("X"));
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
+ #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
#include <cuda_runtime_api.h>
#include <algorithm>
......@@ -321,3 +322,4 @@ class GemmEpilogueAlgoCache {
} // namespace operators
} // namespace paddle
+ #endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/scope_guard.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class FusedGemmEpilogueXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<phi::XPUContext>();
const Tensor* x = ctx.Input<Tensor>("X");
const Tensor* y = ctx.Input<Tensor>("Y");
const Tensor* bias = ctx.Input<Tensor>("Bias");
Tensor* out = ctx.Output<Tensor>("Out");
Tensor* reserve_space = ctx.Output<Tensor>("ReserveSpace");
bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y");
std::string activation = ctx.Attr<std::string>("activation");
VLOG(5) << "trans_x = " << trans_x << " , trans_y = " << trans_y
<< " , activation = " << activation;
auto x_mat_dims =
phi::flatten_to_2d(x->dims(), trans_x ? 1 : x->dims().size() - 1);
// (M * K) * (K * N) for new api use
// int64_t M = trans_x ? x_mat_dims[1] : x_mat_dims[0];
// int64_t K = trans_y ? y->dims()[1] : y->dims()[0];
// int64_t N = trans_y ? y->dims()[0] : y->dims()[1];
// Call the new fused API eventually; for now the steps are invoked separately, pending qingpen's new interface
int r = 0;
xpu::Activation_t act = xpu::Activation_t::LINEAR;
if (activation == "relu") {
act = xpu::Activation_t::RELU;
} else if (activation == "gelu") {
act = xpu::Activation_t::GELU;
}
// fc + bias + act
// 1. fc
phi::XpuFcInfo fc_info;
phi::GetFCInfo(x_mat_dims, y->dims(), trans_x, trans_y, &fc_info);
xpu::Context* xpu_ctx = dev_ctx.x_context();
const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x->data<T>());
const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y->data<T>());
XPUType* out_ptr =
reinterpret_cast<XPUType*>(out->mutable_data<T>(ctx.GetPlace()));
xpu::ctx_guard RAII_GUARD(xpu_ctx);
XPUType* fc_out_ptr = RAII_GUARD.alloc_l3_or_gm<XPUType>(out->numel());
phi::MatMulXPUFunction<XPUType>(
xpu_ctx, x_ptr, y_ptr, fc_out_ptr, fc_info, 1.0f);
XPUType* bias_out_ptr = out_ptr;
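// With a fused activation and a wired ReserveSpace output, stage the
// fc + bias result (the activation input) in ReserveSpace so the grad
// kernel can evaluate the activation gradient from it later.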
if (activation != "none" && reserve_space) {
bias_out_ptr = reinterpret_cast<XPUType*>(
reserve_space->mutable_data<T>(ctx.GetPlace()));
}
// 2 bias
const XPUType* bias_ptr = reinterpret_cast<const XPUType*>(bias->data<T>());
r = xpu::broadcast_add(xpu_ctx,
fc_out_ptr,
bias_ptr,
bias_out_ptr,
{fc_info.m, fc_info.n},
{fc_info.n});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
// 3 act
if (activation == "relu") {
r = xpu::relu(xpu_ctx, bias_out_ptr, out_ptr, out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu");
} else if (activation == "gelu") {
r = xpu::gelu(xpu_ctx, bias_out_ptr, out_ptr, out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu");
}
}
};
template <typename DeviceContext, typename T>
class FusedGemmEpilogueXPUGradKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y");
auto& dev_ctx = ctx.template device_context<phi::XPUContext>();
const Tensor* dout = ctx.Input<Tensor>("DOut");
const Tensor* x = ctx.Input<Tensor>("X");
const Tensor* y = ctx.Input<Tensor>("Y");
const Tensor* reserve_space = ctx.Input<Tensor>("ReserveSpace");
Tensor* dx = ctx.Output<Tensor>("DX");
Tensor* dy = ctx.Output<Tensor>("DY");
Tensor* dbias = ctx.Output<Tensor>("DBias");
std::string activation = "none";
if (ctx.HasAttr("activation")) {
activation = ctx.Attr<std::string>("activation");
} else if (ctx.HasAttr("activation_grad")) {
activation = ctx.Attr<std::string>("activation_grad");
}
auto* xpu_ctx = dev_ctx.x_context();
xpu::ctx_guard RAII_GUARD(xpu_ctx);
const XPUType* dout_ptr = reinterpret_cast<const XPUType*>(dout->data<T>());
const XPUType* dout_fc_ptr = dout_ptr;
const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x->data<T>());
const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y->data<T>());
const XPUType* reserve_space_ptr =
(reserve_space == NULL)
? (reinterpret_cast<const XPUType*>(NULL))
: (reinterpret_cast<const XPUType*>(reserve_space->data<T>()));
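// If an activation was fused, dout is first routed through the activation
// gradient; its result (d_act_input_ptr) then feeds the matmul/bias grads.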
XPUType* d_act_input_ptr = NULL;
if (activation != "none") {
d_act_input_ptr = RAII_GUARD.alloc_l3_or_gm<XPUType>(dout->numel());
dout_fc_ptr = d_act_input_ptr;
}
// 1. act_grad 2. fc_grad 3. dbias
int r = 0;
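// relu_grad/gelu_grad take forward-input and forward-output slots; the
// staged pre-activation from ReserveSpace is passed for both here.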
if (activation == "relu") {
r = xpu::relu_grad(xpu_ctx,
reserve_space_ptr,
reserve_space_ptr,
dout_ptr,
d_act_input_ptr,
dout->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu_grad");
} else if (activation == "gelu") {
r = xpu::gelu_grad(xpu_ctx,
reserve_space_ptr,
reserve_space_ptr,
dout_ptr,
d_act_input_ptr,
dout->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu_grad");
}
auto x_mat_dims =
phi::flatten_to_2d(x->dims(), trans_x ? 1 : x->dims().size() - 1);
phi::XpuFcInfo info_forward;
phi::GetFCInfo(x_mat_dims, y->dims(), trans_x, trans_y, &info_forward);
// 2. fc_grad
const XPUType* a_1 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* b_1 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* a_2 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* b_2 = reinterpret_cast<const XPUType*>(NULL);
XPUType* c_1 =
(dx == NULL)
? reinterpret_cast<XPUType*>(NULL)
: reinterpret_cast<XPUType*>(dx->mutable_data<T>(ctx.GetPlace()));
XPUType* c_2 =
(dy == NULL)
? reinterpret_cast<XPUType*>(NULL)
: reinterpret_cast<XPUType*>(dy->mutable_data<T>(ctx.GetPlace()));
phi::XpuFcInfo info_dx;
phi::XpuFcInfo info_dy;
std::tuple<phi::XpuFcInfo,
phi::XpuFcInfo,
const XPUType*,
const XPUType*,
const XPUType*,
const XPUType*>
fc_info = phi::MatmulGradFcInfo(xpu_ctx,
&RAII_GUARD,
info_forward,
trans_x,
trans_y,
x_ptr,
y_ptr,
dout_fc_ptr);
std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info;
if (dx) {
phi::MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f);
}
if (dy) {
phi::MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f);
}
// 3. dbias
if (dbias) {
XPUType* dbias_ptr =
reinterpret_cast<XPUType*>(dbias->mutable_data<T>(ctx.GetPlace()));
r = xpu::reduce_sum(xpu_ctx,
dout_fc_ptr,
dbias_ptr,
{info_forward.m, info_forward.n},
{0});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
fused_gemm_epilogue,
ops::FusedGemmEpilogueXPUKernel<phi::XPUContext, float>,
ops::FusedGemmEpilogueXPUKernel<phi::XPUContext,
paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
fused_gemm_epilogue_grad,
ops::FusedGemmEpilogueXPUGradKernel<phi::XPUContext, float>,
ops::FusedGemmEpilogueXPUGradKernel<phi::XPUContext,
paddle::platform::float16>);
......@@ -654,7 +654,14 @@ XPUOpMap& get_kl2_ops() {
{"resnet_basic_block_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"resnet_basic_block",
- XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}};
+ XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+ {"fused_gemm_epilogue",
+  XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                pOpKernelType(vartype::FP16, XPUPlace())})},
+ {"fused_gemm_epilogue_grad",
+  XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                pOpKernelType(vartype::FP16, XPUPlace())})},
+ };
return s_xpu2_kernels;
}
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
def get_outputs(DOut, X, Y):
DX = np.dot(DOut, Y.T)
DY = np.dot(X.T, DOut)
DBias = np.sum(DOut, axis=0)
return DX, DY, DBias
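# Hedged reference (not part of the original test): the same backward with a
# fused activation, mirroring the XPU kernel's act_grad -> fc_grad -> dbias
# order. ReserveSpace is assumed to hold the activation input, X @ Y + Bias.
def get_outputs_with_act(DOut, X, Y, reserve_space, activation='none'):
    if activation == 'relu':
        DOut = DOut * (reserve_space > 0)
    elif activation == 'gelu':
        c = np.sqrt(2 / np.pi)
        t = np.tanh(c * (reserve_space + 0.044715 * reserve_space**3))
        DOut = DOut * (0.5 * (1.0 + t) + 0.5 * reserve_space * (1.0 - t**2)
                       * c * (1.0 + 3 * 0.044715 * reserve_space**2))
    return get_outputs(DOut, X, Y)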
class XPUTestFuseGemmGradOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'fused_gemm_epilogue_grad'
self.use_dynamic_create_class = False
class TestFuseGemmEpilogueGradOpDXYBias1(XPUOpTest):
def setUp(self):
paddle.enable_static()
self.op_type = "fused_gemm_epilogue_grad"
self.__class__.no_need_check_grad = True
self.dtype = self.in_type
self.init_data()
def init_data(self):
self.inputs = {
'DOut': np.random.random((8, 128)).astype(self.dtype) - 0.5,
'X': np.random.random((8, 4)).astype(self.dtype) - 0.5,
'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5
}
self.attrs = {"activation": 'none'}
DX, DY, DBias = get_outputs(self.inputs['DOut'], self.inputs['X'],
self.inputs['Y'])
self.outputs = {'DX': DX, 'DY': DY, 'DBias': DBias}
def test_check_output(self):
self.atol = 1e-4
if self.dtype == np.float16:
self.atol = 1e-3
self.check_output_with_place(core.XPUPlace(0), atol=self.atol)
class TestFuseGemmEpilogueGradOpDXYBias2(TestFuseGemmEpilogueGradOpDXYBias1):
def init_data(self):
self.inputs = {
'DOut': np.random.random((8, 128)).astype(self.dtype) - 0.5,
'X': np.random.random((8, 4)).astype(self.dtype) - 0.5,
'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5
}
self.attrs = {"activation": 'none'}
_, DY, DBias = get_outputs(self.inputs['DOut'], self.inputs['X'],
self.inputs['Y'])
self.outputs = {'DY': DY, 'DBias': DBias}
class TestFuseGemmEpilogueGradOpDXYBias3(TestFuseGemmEpilogueGradOpDXYBias1):
def init_data(self):
self.inputs = {
'DOut': np.random.random((8, 128)).astype(self.dtype) - 0.5,
'X': np.random.random((8, 4)).astype(self.dtype) - 0.5,
'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5
}
self.attrs = {"activation": 'none'}
_, DY, _ = get_outputs(self.inputs['DOut'], self.inputs['X'],
self.inputs['Y'])
self.outputs = {'DY': DY}
class TestFuseGemmEpilogueGradOpDXYBias4(TestFuseGemmEpilogueGradOpDXYBias1):
def init_data(self):
self.inputs = {
'DOut': np.random.random((8, 128)).astype(self.dtype) - 0.5,
'X': np.random.random((8, 4)).astype(self.dtype) - 0.5,
'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5
}
self.attrs = {"activation": 'none'}
DX, DY, _ = get_outputs(self.inputs['DOut'], self.inputs['X'],
self.inputs['Y'])
self.outputs = {'DX': DX, 'DY': DY}
support_types = get_xpu_op_support_types('fused_gemm_epilogue_grad')
for stype in support_types:
create_test_class(globals(), XPUTestFuseGemmGradOp, stype)
if __name__ == "__main__":
paddle.enable_static()
np.random.seed(0)
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
def gelu(x):
y_ref = 0.5 * x * (
1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
return y_ref.astype(x.dtype)
def relu(x):
mask = x > 0
return x * mask
def get_output(X, Y, bias, act):
out = np.dot(X, Y) + bias
if act == 'relu':
return relu(out)
elif act == 'gelu':
return gelu(out)
else:
return out
def matmul(x, y, bias, trans_x, trans_y):
x = np.array(x)
if trans_x:
x = np.ascontiguousarray(np.transpose(x))
if trans_y:
y = np.ascontiguousarray(np.transpose(y))
z = np.matmul(x, y)
if bias is None:
return z
else:
return z + bias
def matmul_grad(x, y, bias, dz, trans_x, trans_y):
if trans_x:
if trans_y:
dx = matmul(y, dz, None, True, True)
dy = matmul(dz, x, None, True, True)
else:
dx = matmul(y, dz, None, False, True)
dy = matmul(x, dz, None, False, False)
else:
if trans_y:
dx = matmul(dz, y, None, False, False)
dy = matmul(dz, x, None, True, False)
else:
dx = matmul(dz, y, None, False, True)
dy = matmul(x, dz, None, True, False)
if bias is None:
dbias = None
else:
dbias = np.sum(dz, axis=0, keepdims=False)
return dx, dy, dbias
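# Hedged sketch (not part of the original test): matmul_grad can be
# sanity-checked against a finite difference of <matmul(x, y, bias), dz>,
# which is exactly linear in x, so the comparison is tight.
def _check_matmul_grad_fd(eps=1e-6, tol=1e-3):
    rng = np.random.default_rng(0)
    x, y = rng.random((3, 4)), rng.random((4, 5))
    bias, dz = rng.random(5), rng.random((3, 5))
    dx, _, _ = matmul_grad(x, y, bias, dz, False, False)
    x2 = x.copy()
    x2[0, 0] += eps
    # d<z, dz> / dx[0, 0] should match dx[0, 0]
    fd = (np.sum(matmul(x2, y, bias, False, False) * dz) -
          np.sum(matmul(x, y, bias, False, False) * dz)) / eps
    assert abs(fd - dx[0, 0]) < tol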
class XPUTestFuseGemmOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'fused_gemm_epilogue'
self.use_dynamic_create_class = False
class TestFuseGemmBase(XPUOpTest):
def setUp(self):
self.__class__.no_need_check_grad = True
self.op_type = "fused_gemm_epilogue"
self.init_dtype_type()
self.init_datas_shape_and_attrs()
self.inputs = {
'X': np.random.random(self.x_shape).astype(self.dtype) - 0.5,
'Y': np.random.random(self.y_shape).astype(self.dtype) - 0.5,
'Bias':
np.random.random(self.bias_shape).astype(self.dtype) - 0.5
}
if self.trans_x == True:
numpy_input_x = self.inputs['X'].reshape(
(self.x_shape[0], -1)).T
else:
numpy_input_x = self.inputs['X'].reshape((-1, self.x_shape[-1]))
if self.trans_y == True:
numpy_input_y = self.inputs['Y'].T
else:
numpy_input_y = self.inputs['Y']
self.outputs = {
'Out':
get_output(numpy_input_x, numpy_input_y, self.inputs['Bias'],
self.activation).reshape(self.out_shape)
}
self.attrs = {
"activation": self.activation,
"trans_y": self.trans_y,
"trans_x": self.trans_x
}
def init_dtype_type(self):
self.dtype = self.in_type
self.atol = 1e-4
if self.dtype == np.float16:
self.atol = 1e-3
def init_datas_shape_and_attrs(self):
self.x_shape = [8, 4]
self.y_shape = [4, 128]
self.bias_shape = [
128,
]
self.out_shape = [8, 128]
self.activation = "relu"
self.trans_y = False
self.trans_x = False
def test_check_output(self):
self.check_output_with_place(core.XPUPlace(0), atol=self.atol)
class TestFuseGemmEpilogueOp1(TestFuseGemmBase):
def init_datas_shape_and_attrs(self):
self.x_shape = [4, 8]
self.y_shape = [4, 128]
self.bias_shape = [
128,
]
self.out_shape = [8, 128]
self.activation = "relu"
self.trans_y = False
self.trans_x = True
class TestFuseGemmEpilogueOp2(TestFuseGemmBase):
def init_datas_shape_and_attrs(self):
self.x_shape = [8, 4]
self.y_shape = [128, 4]
self.bias_shape = [
128,
]
self.out_shape = [8, 128]
self.activation = "relu"
self.trans_y = True
self.trans_x = False
class TestFuseGemmEpilogueOp3(TestFuseGemmBase):
def init_datas_shape_and_attrs(self):
self.x_shape = [4, 8]
self.y_shape = [128, 4]
self.bias_shape = [
128,
]
self.out_shape = [8, 128]
self.activation = "relu"
self.trans_y = True
self.trans_x = True
class TestFuseGemmEpilogueOp4(TestFuseGemmBase):
def init_datas_shape_and_attrs(self):
self.x_shape = [2, 2, 8, 4]
self.y_shape = [4, 128]
self.bias_shape = [
128,
]
self.out_shape = [2, 2, 8, 128]
self.activation = "relu"
self.trans_y = False
self.trans_x = False
class TestFuseGemmEpilogueOp5(TestFuseGemmBase):
def init_datas_shape_and_attrs(self):
self.x_shape = [4, 2, 2, 8]
self.y_shape = [4, 128]
self.bias_shape = [
128,
]
self.out_shape = [2, 2, 8, 128]
self.activation = "relu"
self.trans_y = False
self.trans_x = True
class TestFuseGemmEpilogueOp6(TestFuseGemmBase):
def init_datas_shape_and_attrs(self):
self.x_shape = [8, 4]
self.y_shape = [4, 128]
self.bias_shape = [
128,
]
self.out_shape = [8, 128]
self.activation = "gelu"
self.trans_y = False
self.trans_x = False
class TestFuseGemmEpilogueOp7(TestFuseGemmBase):
def init_datas_shape_and_attrs(self):
self.x_shape = [8, 4]
self.y_shape = [4, 128]
self.bias_shape = [
128,
]
self.out_shape = [8, 128]
self.activation = "none"
self.trans_y = False
self.trans_x = False
class TestEagerFusedGemmEpilogue(unittest.TestCase):
def setUp(self):
paddle.set_device('xpu')
def test_case_act(self):
paddle.disable_static()
x_np = np.random.random((8, 4)).astype(np.float32) - 0.5
y_np = np.random.random((4, 128)).astype(np.float32) - 0.5
bias_np = np.random.random((128, )).astype(np.float32) - 0.5
x = paddle.to_tensor(x_np)
y = paddle.to_tensor(y_np)
bias = paddle.to_tensor(bias_np)
x.stop_gradient = False
y.stop_gradient = False
out1 = core.ops.fused_gemm_epilogue(x, y, bias, 'trans_x', False,
'trans_y', False, 'activation',
'none')
out2 = core.ops.fused_gemm_epilogue(x, y, bias, 'trans_x', False,
'trans_y', False, 'activation',
'relu')
out3 = core.ops.fused_gemm_epilogue(x, y, bias, 'trans_x', False,
'trans_y', False, 'activation',
'gelu')
out_np1 = get_output(x_np, y_np, bias_np, 'none')
out_np2 = get_output(x_np, y_np, bias_np, 'relu')
out_np3 = get_output(x_np, y_np, bias_np, 'gelu')
np.testing.assert_allclose(out1, out_np1, atol=1e-04)
np.testing.assert_allclose(out2, out_np2, atol=1e-04)
np.testing.assert_allclose(out3, out_np3, atol=1e-03)
out_grad_np1 = np.random.randint(low=-20, high=20,
size=out_np1.shape).astype(np.float32)
paddle.autograd.backward(out1,
grad_tensors=[paddle.to_tensor(out_grad_np1)])
x_grad_np, y_grad_np, bias_grad_np = matmul_grad(
x_np, y_np, bias_np, out_grad_np1, False, False)
np.testing.assert_allclose(x.grad.numpy(), x_grad_np, atol=1e-02)
self.assertEqual(y_grad_np.shape, y_np.shape)
np.testing.assert_allclose(y.grad.numpy(), y_grad_np, atol=1e-03)
paddle.enable_static()
support_types = get_xpu_op_support_types('fused_gemm_epilogue')
for stype in support_types:
create_test_class(globals(), XPUTestFuseGemmOp, stype)
if __name__ == "__main__":
paddle.enable_static()
np.random.seed(0)
unittest.main()