提交 27197290 编写于 作者: Y yuyang18

matmul support float16/double

上级 705e7345
......@@ -96,10 +96,22 @@ struct CUBlas<platform::float16> {
reinterpret_cast<__half *>(C), ldc));
}
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
static void GEMM_BATCH(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const float16 *alpha, const float16 *A, int lda,
long long int strideA, const float16 *B, // NOLINT
int ldb, long long int strideB, // NOLINT
const float16 *beta, float16 *C, int ldc,
long long int strideC, // NOLINT
int batchCount) {
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(args...));
PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
handle, transa, transb, m, n, k,
reinterpret_cast<const __half *>(alpha),
reinterpret_cast<const __half *>(A), lda, strideA,
reinterpret_cast<const __half *>(B), ldb, strideB,
reinterpret_cast<const __half *>(beta), reinterpret_cast<__half *>(C),
ldc, strideC, batchCount));
#else
PADDLE_THROW("HgemmStridedBatched is not supported on cuda <= 7.5");
#endif
......
......@@ -33,9 +33,10 @@ template struct SetConstant<platform::CUDADeviceContext, int>;
template struct SetConstant<platform::CUDADeviceContext, int64_t>;
template struct SetConstant<platform::CUDADeviceContext, bool>;
#define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
template struct Transpose<platform::CUDADeviceContext, double, RANK>;
#define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
template struct Transpose<platform::CUDADeviceContext, double, RANK>; \
template struct Transpose<platform::CUDADeviceContext, float16, RANK>;
DEFINE_GPU_TRANS(1);
DEFINE_GPU_TRANS(2);
......
......@@ -25,7 +25,7 @@ namespace operators {
* Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
* original x_dim is returned.
*/
static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) {
static framework::DDim RowMatrixFromVector(const framework::DDim &x_dim) {
if (x_dim.size() > 1) {
return x_dim;
}
......@@ -36,7 +36,7 @@ static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) {
* Get column matrix shape from a vector shape. If the ran of y_dim > 1, the
* original y_dim is returned.
*/
static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) {
static framework::DDim ColumnMatrixFromVector(const framework::DDim &y_dim) {
if (y_dim.size() > 1) {
return y_dim;
}
......@@ -46,12 +46,12 @@ static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) {
template <typename DeviceContext, typename T>
class MatMulKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto& x =
void Compute(const framework::ExecutionContext &context) const override {
auto &x =
detail::Ref(context.Input<framework::Tensor>("X"), "Cannot find X");
auto& y =
auto &y =
detail::Ref(context.Input<framework::Tensor>("Y"), "Cannot find Y");
auto* out = context.Output<framework::Tensor>("Out");
auto *out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(context);
......@@ -65,7 +65,7 @@ class MatMulKernel : public framework::OpKernel<T> {
// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3.
static framework::Tensor FoldInitDims(const framework::Tensor& input) {
static framework::Tensor FoldInitDims(const framework::Tensor &input) {
auto output = input;
auto in_dims = input.dims();
if (in_dims.size() == 3) {
......@@ -78,8 +78,8 @@ static framework::Tensor FoldInitDims(const framework::Tensor& input) {
// (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3.
template <typename DeviceContext, typename T>
static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context,
const framework::Tensor& input) {
static framework::Tensor FoldHeadAndLastDims(const DeviceContext &context,
const framework::Tensor &input) {
auto in_dims = input.dims();
if (in_dims.size() != 3) {
return input;
......@@ -102,7 +102,7 @@ static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context,
* If transposed, `H,W` will be swapped.
*/
static void ReshapeTensorIntoMatrixSequence(
framework::Tensor* x, const math::MatDescriptor& descriptor) {
framework::Tensor *x, const math::MatDescriptor &descriptor) {
int64_t h, w;
h = descriptor.height_;
w = descriptor.width_;
......@@ -130,9 +130,9 @@ static void ReshapeTensorIntoMatrixSequence(
* If any of `X` and `Y` has batch size BatchSize, the out will have the
* BatchSize.
*/
static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
framework::Tensor* y,
framework::Tensor* out, bool trans_x,
static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x,
framework::Tensor *y,
framework::Tensor *out, bool trans_x,
bool trans_y) {
auto x_dim = RowMatrixFromVector(x->dims());
auto y_dim = ColumnMatrixFromVector(y->dims());
......@@ -177,10 +177,10 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
template <typename DeviceContext, typename T>
class MatMulGradKernel : public framework::OpKernel<T> {
public:
void MatMul(const framework::ExecutionContext& context,
const framework::Tensor& a, bool trans_a,
const framework::Tensor& b, bool trans_b,
framework::Tensor* out) const {
void MatMul(const framework::ExecutionContext &context,
const framework::Tensor &a, bool trans_a,
const framework::Tensor &b, bool trans_b,
framework::Tensor *out) const {
out->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(context);
auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
......@@ -188,18 +188,18 @@ class MatMulGradKernel : public framework::OpKernel<T> {
blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0));
}
void CalcInputGrad(const framework::ExecutionContext& context,
const framework::Tensor& a, bool trans_a,
bool is_fold_init_dims_a, const framework::Tensor& b,
void CalcInputGrad(const framework::ExecutionContext &context,
const framework::Tensor &a, bool trans_a,
bool is_fold_init_dims_a, const framework::Tensor &b,
bool trans_b, bool is_fold_init_dims_b,
framework::Tensor* out) const {
framework::Tensor *out) const {
if (out == nullptr) return;
bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
out->dims().size() == 2;
if (!need_combine) {
MatMul(context, a, trans_a, b, trans_b, out);
} else {
auto& ctx = context.template device_context<DeviceContext>();
auto &ctx = context.template device_context<DeviceContext>();
MatMul(context, is_fold_init_dims_a
? FoldInitDims(a)
: FoldHeadAndLastDims<DeviceContext, T>(ctx, a),
......@@ -210,13 +210,13 @@ class MatMulGradKernel : public framework::OpKernel<T> {
}
}
void Compute(const framework::ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext &context) const override {
auto x = *context.Input<framework::Tensor>("X");
auto y = *context.Input<framework::Tensor>("Y");
auto dout =
*context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
auto *dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto *dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
bool transpose_x = context.Attr<bool>("transpose_X");
bool transpose_y = context.Attr<bool>("transpose_Y");
......@@ -269,7 +269,7 @@ class MatMulOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* context) const override {
void InferShape(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"),
"Input(X) of MatMulOp should not be null.");
PADDLE_ENFORCE(context->HasInput("Y"),
......@@ -322,7 +322,7 @@ class MatMulOp : public framework::OperatorWithKernel {
class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MatMulOpMaker(OpProto* proto, OpAttrChecker* op_checker)
MatMulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of MatMul op");
AddInput("Y", "The second input of MatMul op");
......@@ -376,7 +376,7 @@ class MatMulOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* context) const override {
void InferShape(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(context->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
......@@ -402,7 +402,7 @@ class MatMulOpGradMaker : public framework::SingleGradOpDescMaker {
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* retv = new framework::OpDesc();
auto *retv = new framework::OpDesc();
retv->SetType("matmul_grad");
retv->SetInput("X", Input("X"));
retv->SetInput("Y", Input("Y"));
......@@ -421,15 +421,27 @@ REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker,
ops::MatMulOpGradMaker);
REGISTER_OPERATOR(matmul_grad, ops::MatMulOpGrad);
REGISTER_OP_CPU_KERNEL(
matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float>);
matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float>,
ops::MatMulKernel<paddle::platform::CPUDeviceContext, double>,
ops::MatMulKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_CPU_KERNEL(
matmul_grad,
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float>);
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>);
#ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL(
matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float>);
matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float>,
ops::MatMulKernel<paddle::platform::CUDADeviceContext, double>,
ops::MatMulKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
matmul_grad,
ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, float>);
ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::MatMulGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册