未验证 提交 923594de 编写于 作者: W Weilong Wu 提交者: GitHub

[XPU] migrate mul to phi (#45502)

* [XPU] migrate mul to phi;test=kunlun

* rm fluid mul xpu op;test=kunlun
上级 0710f058
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/xpu_api_wrapper.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle {
namespace operators {
using framework::OpKernelType;
using framework::Tensor;
template <typename DeviceContext, typename T>
class MulXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X");
const Tensor* y = context.Input<Tensor>("Y");
Tensor* z = context.Output<Tensor>("Out");
const Tensor x_matrix =
x->dims().size() > 2
? framework::ReshapeToMatrix(
*x, context.template Attr<int>("x_num_col_dims"))
: *x;
const Tensor y_matrix =
y->dims().size() > 2
? framework::ReshapeToMatrix(
*y, context.template Attr<int>("y_num_col_dims"))
: *y;
z->mutable_data<T>(context.GetPlace());
const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x_matrix.data<T>());
const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y_matrix.data<T>());
XPUType* out_ptr = reinterpret_cast<XPUType*>(z->data<T>());
bool trans_a = false;
bool trans_b = false;
auto x_dims = x_matrix.dims();
auto y_dims = y_matrix.dims();
phi::XpuFcInfo fc_info;
phi::GetFCInfo(x_dims, y_dims, trans_a, trans_b, &fc_info);
auto& dev_ctx =
context.template device_context<paddle::platform::XPUDeviceContext>();
xpu::Context* xpu_ctx = dev_ctx.x_context();
phi::MatMulXPUFunction<XPUType>(
xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f);
}
};
template <typename DeviceContext, typename T>
class MulGradXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto x_matrix = x->dims().size() > 2
? framework::ReshapeToMatrix(*x, x_num_col_dims)
: static_cast<const Tensor&>(*x);
auto y_matrix = y->dims().size() > 2
? framework::ReshapeToMatrix(*y, y_num_col_dims)
: static_cast<const Tensor&>(*y);
auto* dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
Tensor dout_mat;
dout_mat.Resize({phi::flatten_to_2d(x->dims(), x_num_col_dims)[0],
phi::flatten_to_2d(y->dims(), y_num_col_dims)[1]});
auto* dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
if (dx != nullptr) {
dx->set_lod(x->lod());
}
if (dy != nullptr) {
dy->set_lod(y->lod());
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
phi::XpuFcInfo info_forward;
phi::GetFCInfo(
x_matrix.dims(), y_matrix.dims(), false, false, &info_forward);
const XPUType* dout_ptr = reinterpret_cast<const XPUType*>(dout->data<T>());
const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x->data<T>());
const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y->data<T>());
xpu::Context* xpu_ctx = dev_ctx.x_context();
xpu::ctx_guard RAII_GUARD(xpu_ctx);
// begin calculate
const XPUType* a_1 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* b_1 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* a_2 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* b_2 = reinterpret_cast<const XPUType*>(NULL);
XPUType* c_1 =
(dx == NULL)
? reinterpret_cast<XPUType*>(NULL)
: reinterpret_cast<XPUType*>(dx->mutable_data<T>(ctx.GetPlace()));
XPUType* c_2 =
(dy == NULL)
? reinterpret_cast<XPUType*>(NULL)
: reinterpret_cast<XPUType*>(dy->mutable_data<T>(ctx.GetPlace()));
phi::XpuFcInfo info_dx;
phi::XpuFcInfo info_dy;
std::tuple<phi::XpuFcInfo,
phi::XpuFcInfo,
const XPUType*,
const XPUType*,
const XPUType*,
const XPUType*>
fc_info = phi::MatmulGradFcInfo(xpu_ctx,
&RAII_GUARD,
info_forward,
false,
false,
x_ptr,
y_ptr,
dout_ptr);
std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info;
if (dx) {
phi::MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f);
}
if (dy) {
phi::MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(
mul,
ops::MulXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::MulXPUKernel<paddle::platform::XPUDeviceContext, plat::float16>);
REGISTER_OP_XPU_KERNEL(
mul_grad,
ops::MulGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::MulGradXPUKernel<paddle::platform::XPUDeviceContext, plat::float16>)
#endif
......@@ -13,11 +13,11 @@
// limitations under the License.
#include "paddle/phi/kernels/matmul_grad_kernel.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
namespace phi {
template <typename T, typename Context>
......@@ -81,6 +81,82 @@ void MatmulGradKernel(const Context& dev_ctx,
}
}
template <typename T, typename Context>
void MatmulWithFlattenGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
int x_num_col_dims,
int y_num_col_dims,
DenseTensor* x_grad,
DenseTensor* y_grad) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto x_matrix = x.dims().size() > 2
? paddle::framework::ReshapeToMatrix(x, x_num_col_dims)
: static_cast<const DenseTensor&>(x);
auto y_matrix = y.dims().size() > 2
? paddle::framework::ReshapeToMatrix(y, y_num_col_dims)
: static_cast<const DenseTensor&>(y);
DenseTensor dout_mat;
dout_mat.Resize({phi::flatten_to_2d(x.dims(), x_num_col_dims)[0],
phi::flatten_to_2d(y.dims(), y_num_col_dims)[1]});
if (x_grad != nullptr) {
x_grad->set_lod(x.lod());
}
if (y_grad != nullptr) {
y_grad->set_lod(y.lod());
}
phi::XpuFcInfo info_forward;
phi::GetFCInfo(x_matrix.dims(), y_matrix.dims(), false, false, &info_forward);
const XPUType* dout_ptr =
reinterpret_cast<const XPUType*>(out_grad.data<T>());
const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x.data<T>());
const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y.data<T>());
xpu::Context* xpu_ctx = dev_ctx.x_context();
xpu::ctx_guard RAII_GUARD(xpu_ctx);
// begin calculate
const XPUType* a_1 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* b_1 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* a_2 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* b_2 = reinterpret_cast<const XPUType*>(NULL);
XPUType* c_1 =
(x_grad == NULL)
? reinterpret_cast<XPUType*>(NULL)
: reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(x_grad));
XPUType* c_2 =
(y_grad == NULL)
? reinterpret_cast<XPUType*>(NULL)
: reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(y_grad));
phi::XpuFcInfo info_dx;
phi::XpuFcInfo info_dy;
std::tuple<phi::XpuFcInfo,
phi::XpuFcInfo,
const XPUType*,
const XPUType*,
const XPUType*,
const XPUType*>
fc_info = phi::MatmulGradFcInfo(xpu_ctx,
&RAII_GUARD,
info_forward,
false,
false,
x_ptr,
y_ptr,
dout_ptr);
std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info;
if (x_grad) {
phi::MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f);
}
if (y_grad) {
phi::MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f);
}
}
} // namespace phi
PD_REGISTER_KERNEL(matmul_grad,
......@@ -89,3 +165,10 @@ PD_REGISTER_KERNEL(matmul_grad,
phi::MatmulGradKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(matmul_with_flatten_grad,
XPU,
ALL_LAYOUT,
phi::MatmulWithFlattenGradKernel,
float,
phi::dtype::float16) {}
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
......@@ -42,7 +43,50 @@ void MatmulKernel(const Context& dev_ctx,
MatMulXPUFunction<XPUType>(xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f);
}
template <typename T, typename Context>
void MatmulWithFlattenKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int x_num_col_dims,
int y_num_col_dims,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
const DenseTensor x_matrix =
x.dims().size() > 2
? paddle::framework::ReshapeToMatrix(x, x_num_col_dims)
: x;
const DenseTensor y_matrix =
y.dims().size() > 2
? paddle::framework::ReshapeToMatrix(y, y_num_col_dims)
: y;
dev_ctx.template Alloc<T>(out);
const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x_matrix.data<T>());
const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y_matrix.data<T>());
XPUType* out_ptr = reinterpret_cast<XPUType*>(out->data<T>());
bool trans_a = false;
bool trans_b = false;
auto x_dims = x_matrix.dims();
auto y_dims = y_matrix.dims();
phi::XpuFcInfo fc_info;
phi::GetFCInfo(x_dims, y_dims, trans_a, trans_b, &fc_info);
xpu::Context* xpu_ctx = dev_ctx.x_context();
phi::MatMulXPUFunction<XPUType>(
xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f);
}
} // namespace phi
PD_REGISTER_KERNEL(
matmul, XPU, ALL_LAYOUT, phi::MatmulKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(matmul_with_flatten,
XPU,
ALL_LAYOUT,
phi::MatmulWithFlattenKernel,
float,
phi::dtype::float16) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册