未验证 提交 2553af4f 编写于 作者: F furnace 提交者: GitHub

[Phi] mv kernel (#39861)

[Phi] mv kernel 
上级 22f84122
......@@ -12,7 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/mv_op.h"
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
......@@ -116,10 +122,3 @@ REGISTER_OPERATOR(mv, ops::MVOp, ops::MVOpMaker,
ops::MVOpGradMaker<paddle::framework::OpDesc>,
ops::MVOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(mv_grad, ops::MVOpGrad);
REGISTER_OP_CPU_KERNEL(
mv, ops::MVKernel<paddle::platform::CPUDeviceContext, float>,
ops::MVKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
mv_grad, ops::MVGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MVGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/mv_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
namespace paddle {
namespace operators {
// Fills dX (an m x n row-major matrix) with the outer product
// dOut * Vec^T, i.e. dX[i][j] = dOut[i] * Vec[j].
// Uses a grid-stride loop so any 1-D launch configuration covers all
// m * n elements.
template <typename T>
__global__ void MVGradDxCUDAKernel(const int m, const int n, const T *dout,
                                   const T *vec, T *dx) {
  const int stride = blockDim.x * gridDim.x;
  for (int idx = blockDim.x * blockIdx.x + threadIdx.x; idx < m * n;
       idx += stride) {
    const int row = idx / n;
    const int col = idx % n;
    dx[idx] = dout[row] * vec[col];
  }
}
// Using dimensional constraints on matrix multiplication, it is
// straight-forward to check the following table for when X and Y
// are both matrices.
//
//   dX = | dOut Vec^T
// dVec = | X^T dOut
//
// CUDA backward kernel for mv: X is (m x n), Vec has length n,
// dOut has length m. When requested, dX is filled by a custom
// elementwise kernel and dVec is computed via a transposed GEMV.
template <typename T>
class MVGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *x = context.Input<framework::Tensor>("X");
    auto *vec = context.Input<framework::Tensor>("Vec");
    auto *dout =
        context.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto *dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
    auto *dvec =
        context.Output<framework::Tensor>(framework::GradVarName("Vec"));

    auto dim_x = x->dims();
    int m = dim_x[0];
    int n = dim_x[1];

    // get data ptr
    const T *x_data = x->data<T>();
    const T *vec_data = vec->data<T>();
    const T *dout_data = dout->data<T>();

    auto &dev_ctx =
        context.template device_context<platform::CUDADeviceContext>();

    if (dx) {
      // dX[i][j] = dOut[i] * Vec[j]; the launch configuration is only
      // queried when dX is actually requested.
      T *dx_data = dx->mutable_data<T>(context.GetPlace());
      auto config = GetGpuLaunchConfig1D(dev_ctx, m * n);
      MVGradDxCUDAKernel<
          T><<<config.block_per_grid.x, config.thread_per_block.x, 0,
               dev_ctx.stream()>>>(m, n, dout_data, vec_data, dx_data);
    }

    if (dvec) {
      // dVec = X^T * dOut; the BLAS handle is only obtained when dVec
      // is actually requested.
      T *dvec_data = dvec->mutable_data<T>(context.GetPlace());
      auto blas = phi::funcs::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
      blas.GEMV(true, dim_x[0], dim_x[1], static_cast<T>(1), x_data, dout_data,
                static_cast<T>(0), dvec_data);
    }
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
// Register the fluid CUDA forward and backward kernels for mv
// (float and double only).
REGISTER_OP_CUDA_KERNEL(
    mv, ops::MVKernel<paddle::platform::CUDADeviceContext, float>,
    ops::MVKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    mv_grad, ops::MVGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::MVGradKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Forward kernel of the mv op: Out = X * Vec, where X is an (m x n)
// matrix and Vec a length-n vector. The product is delegated to a
// BLAS GEMV call (no transpose).
template <typename DeviceContext, typename T>
class MVKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *mat = ctx.Input<framework::Tensor>("X");
    auto *in_vec = ctx.Input<framework::Tensor>("Vec");
    auto *out = ctx.Output<framework::Tensor>("Out");

    const auto &mat_dims = mat->dims();

    const T *mat_data = mat->data<T>();
    const T *in_vec_data = in_vec->data<T>();
    T *out_data = out->mutable_data<T>(ctx.GetPlace());

    auto &dev_ctx = ctx.template device_context<DeviceContext>();
    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);

    // Out = 1 * X * Vec + 0 * Out
    blas.GEMV(false, mat_dims[0], mat_dims[1], static_cast<T>(1), mat_data,
              in_vec_data, static_cast<T>(0), out_data);
  }
};
// Using dimensional constraints on matrix multiplication, it is
// straight-forward to check the following table for when X and Y
// are both matrices.
//
//   dX = | dOut vec^T
// dVec = | X^T dOut
//
// CPU/generic backward kernel for mv. Either gradient output may be
// null, in which case it is skipped.
template <typename DeviceContext, typename T>
class MVGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *mat = ctx.Input<framework::Tensor>("X");
    auto *in_vec = ctx.Input<framework::Tensor>("Vec");
    auto *out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto *mat_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    auto *vec_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("Vec"));

    const auto &mat_dims = mat->dims();
    const int num_rows = mat_dims[0];
    const int num_cols = mat_dims[1];

    const T *mat_data = mat->data<T>();
    const T *in_vec_data = in_vec->data<T>();
    const T *dout_data = out_grad->data<T>();

    if (mat_grad) {
      // dX is the outer product dOut * Vec^T, filled row by row.
      T *dx_data = mat_grad->mutable_data<T>(ctx.GetPlace());
      for (int r = 0; r < num_rows; ++r) {
        for (int c = 0; c < num_cols; ++c) {
          dx_data[r * num_cols + c] = dout_data[r] * in_vec_data[c];
        }
      }
    }

    if (vec_grad) {
      // dVec = X^T * dOut via a transposed GEMV.
      T *dvec_data = vec_grad->mutable_data<T>(ctx.GetPlace());
      auto &dev_ctx = ctx.template device_context<DeviceContext>();
      auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
      blas.GEMV(true, mat_dims[0], mat_dims[1], static_cast<T>(1), mat_data,
                dout_data, static_cast<T>(0), dvec_data);
    }
  }
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/mv_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace phi {
template <typename T, typename Context>
void MvGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& vec,
const DenseTensor& out_grad,
DenseTensor* x_grad,
DenseTensor* vec_grad) {
auto dout = out_grad;
auto dx = x_grad;
auto dvec = vec_grad;
auto dim_x = x.dims();
int m = dim_x[0];
int n = dim_x[1];
// get data ptr
const T* x_data = x.data<T>();
const T* vec_data = vec.data<T>();
const T* dout_data = dout.data<T>();
if (dx) {
T* dx_data = dev_ctx.template Alloc<T>(dx);
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
dx_data[i * n + j] = dout_data[i] * vec_data[j];
}
}
}
if (dvec) {
T* dvec_data = dev_ctx.template Alloc<T>(dvec);
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
blas.GEMV(true,
dim_x[0],
dim_x[1],
static_cast<T>(1),
x_data,
dout_data,
static_cast<T>(0),
dvec_data);
}
}
} // namespace phi
// Register the phi CPU backward kernel for mv (float, double).
PD_REGISTER_KERNEL(mv_grad, CPU, ALL_LAYOUT, phi::MvGradKernel, float, double) {
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/mv_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/mv_kernel_impl.h"
// Register the phi CPU forward kernel for mv (float, double).
PD_REGISTER_KERNEL(mv, CPU, ALL_LAYOUT, phi::MvKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/mv_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace phi {
// Fills dx (an m x n row-major matrix) with the outer product
// dout * vec^T, i.e. dx[i][j] = dout[i] * vec[j].
// Grid-stride loop: any 1-D launch configuration covers all m * n
// elements.
template <typename T>
__global__ void MVGradDxCUDAKernel(
    const int m, const int n, const T *dout, const T *vec, T *dx) {
  const int stride = blockDim.x * gridDim.x;
  for (int idx = blockDim.x * blockIdx.x + threadIdx.x; idx < m * n;
       idx += stride) {
    const int row = idx / n;
    const int col = idx % n;
    dx[idx] = dout[row] * vec[col];
  }
}
// phi GPU backward kernel of mv. Given x (m x n), vec (n) and
// out_grad (m), writes, when the output pointer is non-null:
//   x_grad   = out_grad * vec^T  (custom elementwise CUDA kernel)
//   vec_grad = x^T * out_grad    (transposed GEMV)
template <typename T, typename Context>
void MvGradKernel(const Context &dev_ctx,
                  const DenseTensor &x,
                  const DenseTensor &vec,
                  const DenseTensor &out_grad,
                  DenseTensor *x_grad,
                  DenseTensor *vec_grad) {
  auto dim_x = x.dims();
  int m = dim_x[0];
  int n = dim_x[1];

  // get data ptr
  const T *x_data = x.data<T>();
  const T *vec_data = vec.data<T>();
  const T *dout_data = out_grad.data<T>();

  if (x_grad) {
    // Only query the launch configuration when x_grad is actually
    // requested.
    T *dx_data = dev_ctx.template Alloc<T>(x_grad);
    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, m * n);
    MVGradDxCUDAKernel<T><<<config.block_per_grid.x,
                            config.thread_per_block.x,
                            0,
                            dev_ctx.stream()>>>(
        m, n, dout_data, vec_data, dx_data);
  }

  if (vec_grad) {
    // Only obtain the BLAS handle when vec_grad is actually requested.
    T *dvec_data = dev_ctx.template Alloc<T>(vec_grad);
    auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
    blas.GEMV(true,
              dim_x[0],
              dim_x[1],
              static_cast<T>(1),
              x_data,
              dout_data,
              static_cast<T>(0),
              dvec_data);
  }
}
} // namespace phi
// Register the phi GPU backward kernel for mv (float, double).
PD_REGISTER_KERNEL(mv_grad, GPU, ALL_LAYOUT, phi::MvGradKernel, float, double) {
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/mv_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/mv_kernel_impl.h"
// Register the phi GPU forward kernel for mv (float, double).
PD_REGISTER_KERNEL(mv, GPU, ALL_LAYOUT, phi::MvKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace phi {
// phi forward kernel of mv: out = x * vec for an (m x n) matrix x and
// a length-n vector vec. Shared by CPU and GPU builds; the product is
// delegated to a BLAS GEMV call (no transpose).
template <typename T, typename Context>
void MvKernel(const Context& dev_ctx,
              const DenseTensor& x,
              const DenseTensor& vec,
              DenseTensor* out) {
  const auto& x_dims = x.dims();

  const T* mat_data = x.data<T>();
  const T* vec_data = vec.data<T>();
  T* out_data = dev_ctx.template Alloc<T>(out);

  auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
  // out = 1 * x * vec + 0 * out
  blas.GEMV(false,
            x_dims[0],
            x_dims[1],
            static_cast<T>(1),
            mat_data,
            vec_data,
            static_cast<T>(0),
            out_data);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
// Using dimensional constraints on matrix multiplication, it is
// straight-forward to check the following table for when X and Y
// are both matrices.
//
//   dX = | dOut vec^T
// dVec = | X^T dOut
//
// Backward of mv. Inputs: x (m x n matrix), vec (length-n vector),
// out_grad (length-m vector dOut). When non-null, x_grad receives
// dOut * vec^T and vec_grad receives x^T * dOut.
template <typename T, typename Context>
void MvGradKernel(const Context& ctx,
                  const DenseTensor& x,
                  const DenseTensor& vec,
                  const DenseTensor& out_grad,
                  DenseTensor* x_grad,
                  DenseTensor* vec_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
// Forward of mv: computes out = x * vec, where x is an (m x n)
// matrix and vec a length-n vector.
template <typename T, typename Context>
void MvKernel(const Context& ctx,
              const DenseTensor& x,
              const DenseTensor& vec,
              DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
// Maps the fluid "mv" op onto the phi "mv" kernel signature:
// inputs {X, Vec}, no attributes, output {Out}.
KernelSignature MvOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("mv", {"X", "Vec"}, {}, {"Out"});
}
// Maps the fluid "mv_grad" op onto the phi "mv_grad" kernel signature:
// inputs {X, Vec, Out@GRAD}, no attributes, outputs {X@GRAD, Vec@GRAD}.
KernelSignature MvGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("mv_grad",
                         {"X", "Vec", GradVarName("Out")},
                         {},
                         {GradVarName("X"), GradVarName("Vec")});
}
} // namespace phi
// Wire the fluid op names to their phi argument-mapping functions.
PD_REGISTER_ARG_MAPPING_FN(mv, phi::MvOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(mv_grad, phi::MvGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册