未验证 提交 55f6fb3d 编写于 作者: S Sławomir Siwek 提交者: GitHub

[PHI] Migrate mul_grad kernel (#48061)

* cleanup unused code

* unify is_int8 is_bfloat16

* Simplify matmul_v2 FWD kernel

* remove RunKernel methods

* remove import namespace

* remove headers

* clean fluid/phi cross imports

* remove fluid axpy_handler

* delete fluid methods

* activations

* OneDNNMemDesc

* MKLDNNFormatForSize

* MatchShapeToLayout

* MKLDNNMemoryFormat

* MKLDNNFormat

* ReorderMKLDNNHandler

* to_void_cast

* review suggestions

* interpolate

* remove fluid depedency

* init

* ExecuteMatMulV2

* rm fluid kernel

* matmul_grad

* remove mutable_data

* mul_grad
上级 02dfd18d
......@@ -489,83 +489,6 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
}
};
template <typename XT>
class MulGradMKLDNNKernel : public MulMKLDNNKernel<XT> {
public:
void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); }
private:
void RunKernel(const ExecutionContext &ctx) const {
const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &onednn_engine = dev_ctx.GetEngine();
const auto *x = ctx.Input<LoDTensor>("X");
const auto *y = ctx.Input<LoDTensor>("Y");
const auto *dout =
ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
auto *dy = ctx.Output<LoDTensor>(framework::GradVarName("Y"));
int x_num_col_dims = ctx.Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.Attr<int>("y_num_col_dims");
const Tensor x_matrix = x->dims().size() > 2
? framework::ReshapeToMatrix(*x, x_num_col_dims)
: static_cast<const Tensor &>(*x);
const Tensor y_matrix = y->dims().size() > 2
? framework::ReshapeToMatrix(*y, y_num_col_dims)
: static_cast<const Tensor &>(*y);
Tensor dout_matrix = *dout;
dout_matrix.Resize({phi::flatten_to_2d(x->dims(), x_num_col_dims)[0],
phi::flatten_to_2d(y->dims(), y_num_col_dims)[1]});
// adding mb dim because MatMulV2 handler needs it
std::vector<int64_t> x_dims(3, 1);
std::vector<int64_t> y_dims(3, 1);
std::vector<int64_t> dout_dims(3, 1);
x_dims[1] = x_matrix.dims()[0];
x_dims[2] = x_matrix.dims()[1];
y_dims[1] = y_matrix.dims()[0];
y_dims[2] = y_matrix.dims()[1];
dout_dims[1] = dout_matrix.dims()[0];
dout_dims[2] = dout_matrix.dims()[1];
if (dx != nullptr) {
dx->set_lod(x->lod());
this->ExecuteMatMul(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
&dout_matrix,
dout_dims,
false,
&y_matrix,
y_dims,
true,
static_cast<Tensor *>(dx));
}
if (dy != nullptr) {
dy->set_lod(y->lod());
this->ExecuteMatMul(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
&x_matrix,
x_dims,
true,
&dout_matrix,
dout_dims,
false,
static_cast<Tensor *>(dy));
}
}
};
} // namespace operators
} // namespace paddle
......@@ -578,9 +501,3 @@ REGISTER_OP_KERNEL(mul,
ops::MulMKLDNNINT8Kernel<int8_t>,
ops::MulMKLDNNKernel<paddle::platform::bfloat16>,
ops::MulMKLDNNKernel<float>);
REGISTER_OP_KERNEL(mul_grad,
MKLDNN,
::paddle::platform::CPUPlace,
ops::MulGradMKLDNNKernel<paddle::platform::bfloat16>,
ops::MulGradMKLDNNKernel<float>);
......@@ -1912,6 +1912,47 @@ class MatmulOneDNNHandler
}
};
template <typename T>
static void ExecuteMul(const OneDNNContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& y_dims,
bool trans_x,
bool trans_y,
DenseTensor* out) {
static const std::vector<int64_t> vec_placeholder;
MatmulOneDNNHandler<T, T, T> handler(dev_ctx,
x_dims,
y_dims,
trans_x,
trans_y,
vec_placeholder,
vec_placeholder,
false);
const auto src_memory_p = handler.AcquireSrcMemory(&x);
const auto weights_memory_p = handler.AcquireWeightsMemory(&y);
const auto dst_memory_p = handler.AcquireDstMemory(dev_ctx, out);
auto matmul_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, dnnl::memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
auto& astream = OneDNNContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
// This kernel is flattening dims so then we need to unflattened version
// that should be set in out reshape require plain layout, but
// MatmulV2MKLDNNHanlder enforces one so it should work
out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
}
template <typename T, typename T_out>
void ExecuteMatmul(const OneDNNContext& dev_ctx,
const DenseTensor& x,
......
......@@ -153,6 +153,49 @@ void MatmulGradKernel(const Context &dev_ctx,
dy->Resize(y.dims());
}
template <typename T, typename Context>
void MatmulWithFlattenGradKernel(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &y,
const DenseTensor &out_grad,
int x_num_col_dims,
int y_num_col_dims,
DenseTensor *x_grad,
DenseTensor *y_grad) {
const DenseTensor reshaped_y =
paddle::framework::ReshapeToMatrix(y, y_num_col_dims);
const DenseTensor reshaped_x =
paddle::framework::ReshapeToMatrix(x, x_num_col_dims);
const DenseTensor x_matrix = x.dims().size() > 2 ? reshaped_x : x;
const DenseTensor y_matrix = y.dims().size() > 2 ? reshaped_y : y;
DenseTensor dout_matrix = out_grad;
dout_matrix.Resize({flatten_to_2d(x.dims(), x_num_col_dims)[0],
flatten_to_2d(y.dims(), y_num_col_dims)[1]});
// adding mb dim because MatMulV2 handler needs it
std::vector<int64_t> x_dims(3, 1);
std::vector<int64_t> y_dims(3, 1);
std::vector<int64_t> dout_dims(3, 1);
x_dims[1] = x_matrix.dims()[0];
x_dims[2] = x_matrix.dims()[1];
y_dims[1] = y_matrix.dims()[0];
y_dims[2] = y_matrix.dims()[1];
dout_dims[1] = dout_matrix.dims()[0];
dout_dims[2] = dout_matrix.dims()[1];
if (x_grad != nullptr) {
x_grad->set_lod(x.lod());
funcs::ExecuteMul<T>(
dev_ctx, dout_matrix, y_matrix, dout_dims, y_dims, false, true, x_grad);
}
if (y_grad != nullptr) {
y_grad->set_lod(y.lod());
funcs::ExecuteMul<T>(
dev_ctx, x_matrix, dout_matrix, x_dims, dout_dims, true, false, y_grad);
}
}
} // namespace phi
PD_REGISTER_KERNEL(matmul_grad,
......@@ -161,3 +204,10 @@ PD_REGISTER_KERNEL(matmul_grad,
phi::MatmulGradKernel,
float,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(matmul_with_flatten_grad,
OneDNN,
ONEDNN,
phi::MatmulWithFlattenGradKernel,
float,
phi::dtype::bfloat16) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册