未验证 提交 ab631be9 编写于 作者: J Juncheng 提交者: GitHub

MatmulPrimitive interface (#6462)

* matmul api

* BlasTransposeType

* BatchMatmul/BroadcastMatmul

* fix

* fix

* enum=>enum class
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 027c4793
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_PRIMITIVE_BATCH_MATMUL_H_
#define ONEFLOW_CORE_PRIMITIVE_BATCH_MATMUL_H_
#include "oneflow/core/primitive/include/primitive.h"
#include "oneflow/core/primitive/include/blas.h"
#include "oneflow/core/common/scalar.h"
namespace oneflow {
namespace primitive {
class BatchMatmul : public Primitive {
public:
OF_DISALLOW_COPY_AND_MOVE(BatchMatmul);
BatchMatmul() = default;
~BatchMatmul() override = default;
virtual DataType a_type() const = 0;
virtual DataType b_type() const = 0;
virtual DataType c_type() const = 0;
virtual BlasTransposeType transpose_a() const = 0;
virtual BlasTransposeType transpose_b() const = 0;
virtual void Launch(StreamContext* stream_ctx, size_t num_batches, size_t m, size_t n, size_t k,
Scalar alpha, const void* a, const void* b, Scalar beta, void* c) = 0;
};
class BatchMatmulFactory : public Factory<BatchMatmul> {
public:
OF_DISALLOW_COPY_AND_MOVE(BatchMatmulFactory);
BatchMatmulFactory() = default;
~BatchMatmulFactory() override = default;
virtual std::unique_ptr<BatchMatmul> New(DataType data_type, BlasTransposeType transpose_a,
BlasTransposeType transpose_b) = 0;
};
} // namespace primitive
} // namespace oneflow
#endif // ONEFLOW_CORE_PRIMITIVE_BATCH_MATMUL_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_PRIMITIVE_BLAS_H_
#define ONEFLOW_CORE_PRIMITIVE_BLAS_H_
namespace oneflow {
namespace primitive {
enum class BlasTransposeType {
N = 0,
T,
};
} // namespace primitive
} // namespace oneflow
#endif // ONEFLOW_CORE_PRIMITIVE_BLAS_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_PRIMITIVE_BROADCAST_MATMUL_H_
#define ONEFLOW_CORE_PRIMITIVE_BROADCAST_MATMUL_H_
#include "oneflow/core/primitive/include/primitive.h"
#include "oneflow/core/primitive/include/blas.h"
#include "oneflow/core/common/scalar.h"
namespace oneflow {
namespace primitive {
class BroadcastMatmul : public Primitive {
public:
OF_DISALLOW_COPY_AND_MOVE(BroadcastMatmul);
BroadcastMatmul() = default;
~BroadcastMatmul() override = default;
virtual DataType a_type() const = 0;
virtual DataType b_type() const = 0;
virtual DataType c_type() const = 0;
virtual BlasTransposeType transpose_a() const = 0;
virtual BlasTransposeType transpose_b() const = 0;
virtual void Launch(StreamContext* stream_ctx, Scalar alpha, size_t num_a_dims, int64_t* a_dims,
const void* a, size_t num_b_dims, int64_t* b_dims, const void* b, Scalar beta,
void* c) = 0;
};
class BroadcastMatmulFactory : public Factory<BroadcastMatmul> {
public:
OF_DISALLOW_COPY_AND_MOVE(BroadcastMatmulFactory);
BroadcastMatmulFactory() = default;
~BroadcastMatmulFactory() override = default;
virtual std::unique_ptr<BroadcastMatmulFactory> New(DataType data_type,
BlasTransposeType transpose_a,
BlasTransposeType transpose_b,
size_t max_num_dims) = 0;
};
} // namespace primitive
} // namespace oneflow
#endif // ONEFLOW_CORE_PRIMITIVE_BROADCAST_MATMUL_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_PRIMITIVE_MATMUL_H_
#define ONEFLOW_CORE_PRIMITIVE_MATMUL_H_
#include "oneflow/core/primitive/include/primitive.h"
#include "oneflow/core/primitive/include/blas.h"
#include "oneflow/core/common/scalar.h"
namespace oneflow {
namespace primitive {
class Matmul : public Primitive {
public:
OF_DISALLOW_COPY_AND_MOVE(Matmul);
Matmul() = default;
~Matmul() override = default;
virtual DataType a_type() const = 0;
virtual DataType b_type() const = 0;
virtual DataType c_type() const = 0;
virtual BlasTransposeType transpose_a() const = 0;
virtual BlasTransposeType transpose_b() const = 0;
virtual void Launch(StreamContext* stream_ctx, size_t m, size_t n, size_t k, Scalar alpha,
const void* a, const void* b, Scalar beta, void* c) = 0;
};
class MatmulFactory : public Factory<Matmul> {
public:
OF_DISALLOW_COPY_AND_MOVE(MatmulFactory);
MatmulFactory() = default;
~MatmulFactory() override = default;
virtual std::unique_ptr<Matmul> New(DataType data_type, BlasTransposeType transpose_a,
BlasTransposeType transpose_b) = 0;
};
} // namespace primitive
} // namespace oneflow
#endif // ONEFLOW_CORE_PRIMITIVE_MATMUL_H_
......@@ -22,7 +22,7 @@ namespace oneflow {
namespace primitive {
enum MemcpyKind {
enum class MemcpyKind {
kAuto = 0,
kHtoD,
kDtoH,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册