提交 27197290 编写于 作者: Y yuyang18

matmul support float16/double

上级 705e7345
...@@ -96,10 +96,22 @@ struct CUBlas<platform::float16> { ...@@ -96,10 +96,22 @@ struct CUBlas<platform::float16> {
reinterpret_cast<__half *>(C), ldc)); reinterpret_cast<__half *>(C), ldc));
} }
template <typename... ARGS> static void GEMM_BATCH(cublasHandle_t handle, cublasOperation_t transa,
static void GEMM_BATCH(ARGS... args) { cublasOperation_t transb, int m, int n, int k,
const float16 *alpha, const float16 *A, int lda,
long long int strideA, const float16 *B, // NOLINT
int ldb, long long int strideB, // NOLINT
const float16 *beta, float16 *C, int ldc,
long long int strideC, // NOLINT
int batchCount) {
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(args...)); PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
handle, transa, transb, m, n, k,
reinterpret_cast<const __half *>(alpha),
reinterpret_cast<const __half *>(A), lda, strideA,
reinterpret_cast<const __half *>(B), ldb, strideB,
reinterpret_cast<const __half *>(beta), reinterpret_cast<__half *>(C),
ldc, strideC, batchCount));
#else #else
PADDLE_THROW("HgemmStridedBatched is not supported on cuda <= 7.5"); PADDLE_THROW("HgemmStridedBatched is not supported on cuda <= 7.5");
#endif #endif
......
...@@ -35,7 +35,8 @@ template struct SetConstant<platform::CUDADeviceContext, bool>; ...@@ -35,7 +35,8 @@ template struct SetConstant<platform::CUDADeviceContext, bool>;
#define DEFINE_GPU_TRANS(RANK) \ #define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<platform::CUDADeviceContext, float, RANK>; \ template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
template struct Transpose<platform::CUDADeviceContext, double, RANK>; template struct Transpose<platform::CUDADeviceContext, double, RANK>; \
template struct Transpose<platform::CUDADeviceContext, float16, RANK>;
DEFINE_GPU_TRANS(1); DEFINE_GPU_TRANS(1);
DEFINE_GPU_TRANS(2); DEFINE_GPU_TRANS(2);
......
...@@ -25,7 +25,7 @@ namespace operators { ...@@ -25,7 +25,7 @@ namespace operators {
* Get row matrix shape from a vector shape. If the rank of x_dim > 1, the * Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
* original x_dim is returned. * original x_dim is returned.
*/ */
static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) { static framework::DDim RowMatrixFromVector(const framework::DDim &x_dim) {
if (x_dim.size() > 1) { if (x_dim.size() > 1) {
return x_dim; return x_dim;
} }
...@@ -36,7 +36,7 @@ static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) { ...@@ -36,7 +36,7 @@ static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) {
* Get column matrix shape from a vector shape. If the ran of y_dim > 1, the * Get column matrix shape from a vector shape. If the ran of y_dim > 1, the
* original y_dim is returned. * original y_dim is returned.
*/ */
static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) { static framework::DDim ColumnMatrixFromVector(const framework::DDim &y_dim) {
if (y_dim.size() > 1) { if (y_dim.size() > 1) {
return y_dim; return y_dim;
} }
...@@ -46,12 +46,12 @@ static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) { ...@@ -46,12 +46,12 @@ static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MatMulKernel : public framework::OpKernel<T> { class MatMulKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto& x = auto &x =
detail::Ref(context.Input<framework::Tensor>("X"), "Cannot find X"); detail::Ref(context.Input<framework::Tensor>("X"), "Cannot find X");
auto& y = auto &y =
detail::Ref(context.Input<framework::Tensor>("Y"), "Cannot find Y"); detail::Ref(context.Input<framework::Tensor>("Y"), "Cannot find Y");
auto* out = context.Output<framework::Tensor>("Out"); auto *out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(context); auto blas = math::GetBlas<DeviceContext, T>(context);
...@@ -65,7 +65,7 @@ class MatMulKernel : public framework::OpKernel<T> { ...@@ -65,7 +65,7 @@ class MatMulKernel : public framework::OpKernel<T> {
// Reshape a rank-3 tensor from P x M x N to (P * M) x N. // Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3. // Identity op if the tensor is not of rank 3.
static framework::Tensor FoldInitDims(const framework::Tensor& input) { static framework::Tensor FoldInitDims(const framework::Tensor &input) {
auto output = input; auto output = input;
auto in_dims = input.dims(); auto in_dims = input.dims();
if (in_dims.size() == 3) { if (in_dims.size() == 3) {
...@@ -78,8 +78,8 @@ static framework::Tensor FoldInitDims(const framework::Tensor& input) { ...@@ -78,8 +78,8 @@ static framework::Tensor FoldInitDims(const framework::Tensor& input) {
// (Warning: This requires transposing data and writes into new memory.) // (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3. // Identity op if the tensor is not of rank 3.
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context, static framework::Tensor FoldHeadAndLastDims(const DeviceContext &context,
const framework::Tensor& input) { const framework::Tensor &input) {
auto in_dims = input.dims(); auto in_dims = input.dims();
if (in_dims.size() != 3) { if (in_dims.size() != 3) {
return input; return input;
...@@ -102,7 +102,7 @@ static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context, ...@@ -102,7 +102,7 @@ static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context,
* If transposed, `H,W` will be swapped. * If transposed, `H,W` will be swapped.
*/ */
static void ReshapeTensorIntoMatrixSequence( static void ReshapeTensorIntoMatrixSequence(
framework::Tensor* x, const math::MatDescriptor& descriptor) { framework::Tensor *x, const math::MatDescriptor &descriptor) {
int64_t h, w; int64_t h, w;
h = descriptor.height_; h = descriptor.height_;
w = descriptor.width_; w = descriptor.width_;
...@@ -130,9 +130,9 @@ static void ReshapeTensorIntoMatrixSequence( ...@@ -130,9 +130,9 @@ static void ReshapeTensorIntoMatrixSequence(
* If any of `X` and `Y` has batch size BatchSize, the out will have the * If any of `X` and `Y` has batch size BatchSize, the out will have the
* BatchSize. * BatchSize.
*/ */
static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x, static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x,
framework::Tensor* y, framework::Tensor *y,
framework::Tensor* out, bool trans_x, framework::Tensor *out, bool trans_x,
bool trans_y) { bool trans_y) {
auto x_dim = RowMatrixFromVector(x->dims()); auto x_dim = RowMatrixFromVector(x->dims());
auto y_dim = ColumnMatrixFromVector(y->dims()); auto y_dim = ColumnMatrixFromVector(y->dims());
...@@ -177,10 +177,10 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x, ...@@ -177,10 +177,10 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MatMulGradKernel : public framework::OpKernel<T> { class MatMulGradKernel : public framework::OpKernel<T> {
public: public:
void MatMul(const framework::ExecutionContext& context, void MatMul(const framework::ExecutionContext &context,
const framework::Tensor& a, bool trans_a, const framework::Tensor &a, bool trans_a,
const framework::Tensor& b, bool trans_b, const framework::Tensor &b, bool trans_b,
framework::Tensor* out) const { framework::Tensor *out) const {
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(context); auto blas = math::GetBlas<DeviceContext, T>(context);
auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
...@@ -188,18 +188,18 @@ class MatMulGradKernel : public framework::OpKernel<T> { ...@@ -188,18 +188,18 @@ class MatMulGradKernel : public framework::OpKernel<T> {
blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0)); blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0));
} }
void CalcInputGrad(const framework::ExecutionContext& context, void CalcInputGrad(const framework::ExecutionContext &context,
const framework::Tensor& a, bool trans_a, const framework::Tensor &a, bool trans_a,
bool is_fold_init_dims_a, const framework::Tensor& b, bool is_fold_init_dims_a, const framework::Tensor &b,
bool trans_b, bool is_fold_init_dims_b, bool trans_b, bool is_fold_init_dims_b,
framework::Tensor* out) const { framework::Tensor *out) const {
if (out == nullptr) return; if (out == nullptr) return;
bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) && bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
out->dims().size() == 2; out->dims().size() == 2;
if (!need_combine) { if (!need_combine) {
MatMul(context, a, trans_a, b, trans_b, out); MatMul(context, a, trans_a, b, trans_b, out);
} else { } else {
auto& ctx = context.template device_context<DeviceContext>(); auto &ctx = context.template device_context<DeviceContext>();
MatMul(context, is_fold_init_dims_a MatMul(context, is_fold_init_dims_a
? FoldInitDims(a) ? FoldInitDims(a)
: FoldHeadAndLastDims<DeviceContext, T>(ctx, a), : FoldHeadAndLastDims<DeviceContext, T>(ctx, a),
...@@ -210,13 +210,13 @@ class MatMulGradKernel : public framework::OpKernel<T> { ...@@ -210,13 +210,13 @@ class MatMulGradKernel : public framework::OpKernel<T> {
} }
} }
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto x = *context.Input<framework::Tensor>("X"); auto x = *context.Input<framework::Tensor>("X");
auto y = *context.Input<framework::Tensor>("Y"); auto y = *context.Input<framework::Tensor>("Y");
auto dout = auto dout =
*context.Input<framework::Tensor>(framework::GradVarName("Out")); *context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X")); auto *dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = context.Output<framework::Tensor>(framework::GradVarName("Y")); auto *dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
bool transpose_x = context.Attr<bool>("transpose_X"); bool transpose_x = context.Attr<bool>("transpose_X");
bool transpose_y = context.Attr<bool>("transpose_Y"); bool transpose_y = context.Attr<bool>("transpose_Y");
...@@ -269,7 +269,7 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -269,7 +269,7 @@ class MatMulOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(framework::InferShapeContext* context) const override { void InferShape(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"), PADDLE_ENFORCE(context->HasInput("X"),
"Input(X) of MatMulOp should not be null."); "Input(X) of MatMulOp should not be null.");
PADDLE_ENFORCE(context->HasInput("Y"), PADDLE_ENFORCE(context->HasInput("Y"),
...@@ -322,7 +322,7 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -322,7 +322,7 @@ class MatMulOp : public framework::OperatorWithKernel {
class MatMulOpMaker : public framework::OpProtoAndCheckerMaker { class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MatMulOpMaker(OpProto* proto, OpAttrChecker* op_checker) MatMulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of MatMul op"); AddInput("X", "The first input of MatMul op");
AddInput("Y", "The second input of MatMul op"); AddInput("Y", "The second input of MatMul op");
...@@ -376,7 +376,7 @@ class MatMulOpGrad : public framework::OperatorWithKernel { ...@@ -376,7 +376,7 @@ class MatMulOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(framework::InferShapeContext* context) const override { void InferShape(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(context->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(context->HasInput("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE(context->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
...@@ -402,7 +402,7 @@ class MatMulOpGradMaker : public framework::SingleGradOpDescMaker { ...@@ -402,7 +402,7 @@ class MatMulOpGradMaker : public framework::SingleGradOpDescMaker {
protected: protected:
std::unique_ptr<framework::OpDesc> Apply() const override { std::unique_ptr<framework::OpDesc> Apply() const override {
auto* retv = new framework::OpDesc(); auto *retv = new framework::OpDesc();
retv->SetType("matmul_grad"); retv->SetType("matmul_grad");
retv->SetInput("X", Input("X")); retv->SetInput("X", Input("X"));
retv->SetInput("Y", Input("Y")); retv->SetInput("Y", Input("Y"));
...@@ -421,15 +421,27 @@ REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker, ...@@ -421,15 +421,27 @@ REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker,
ops::MatMulOpGradMaker); ops::MatMulOpGradMaker);
REGISTER_OPERATOR(matmul_grad, ops::MatMulOpGrad); REGISTER_OPERATOR(matmul_grad, ops::MatMulOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float>); matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float>,
ops::MatMulKernel<paddle::platform::CPUDeviceContext, double>,
ops::MatMulKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
matmul_grad, matmul_grad,
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float>); ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float>); matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float>,
ops::MatMulKernel<paddle::platform::CUDADeviceContext, double>,
ops::MatMulKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
matmul_grad, matmul_grad,
ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, float>); ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::MatMulGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
#endif #endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册