未验证 提交 86f77141 编写于 作者: J Juncheng 提交者: GitHub

Add cast primitive (#6234)

* Add cast primitive

* fix
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 7bb010e7
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_PRIMITIVE_CAST_H_
#define ONEFLOW_CORE_PRIMITIVE_CAST_H_
#include "oneflow/core/primitive/primitive.h"
namespace oneflow {
namespace primitive {
class Cast : public Primitive {
public:
OF_DISALLOW_COPY_AND_MOVE(Cast);
Cast() = default;
~Cast() override = default;
virtual void Launch(StreamContext* stream_ctx, const void* from, void* to, size_t count) = 0;
};
class CastFactory : public Factory<Cast> {
public:
OF_DISALLOW_COPY_AND_MOVE(CastFactory);
CastFactory() = default;
virtual ~CastFactory() = default;
virtual std::unique_ptr<Cast> New(DataType from, DataType to) = 0;
};
} // namespace primitive
} // namespace oneflow
#endif // ONEFLOW_CORE_PRIMITIVE_CAST_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/primitive/cast.h"
#include "oneflow/core/primitive/cpu/type_seq.h"
namespace oneflow {
namespace primitive {
namespace {
template<typename From, typename To>
void CastCpu(const From* from, To* to, size_t count) {
for (size_t i = 0; i < count; ++i) { to[i] = static_cast<To>(from[i]); }
}
template<typename From, typename To>
class CastImpl : public Cast {
public:
OF_DISALLOW_COPY_AND_MOVE(CastImpl);
CastImpl() = default;
~CastImpl() = default;
void Launch(StreamContext* stream_ctx, const void* from, void* to, size_t count) override {
CastCpu(reinterpret_cast<const From*>(from), reinterpret_cast<To*>(to), count);
}
};
template<typename From, typename To>
std::unique_ptr<Cast> NewCast() {
return std::unique_ptr<Cast>(new CastImpl<From, To>());
}
class CastFactoryImpl : public CastFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(CastFactoryImpl);
CastFactoryImpl() = default;
~CastFactoryImpl() override = default;
std::unique_ptr<Cast> New(DataType from, DataType to) override {
#define MAKE_NEW_CAST_ENTRY(from_pair, to_pair) \
{std::make_pair(OF_PP_PAIR_SECOND(from_pair), OF_PP_PAIR_SECOND(to_pair)), \
NewCast<OF_PP_PAIR_FIRST(from_pair), OF_PP_PAIR_FIRST(to_pair)>},
static const std::map<std::pair<DataType, DataType>, std::function<std::unique_ptr<Cast>()>>
new_cast_handle{OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_CAST_ENTRY, CPU_PRIMITIVE_ALL_TYPE_SEQ, CPU_PRIMITIVE_ALL_TYPE_SEQ)};
#undef MAKE_NEW_CAST_ENTRY
const auto it = new_cast_handle.find(std::make_pair(from, to));
if (it != new_cast_handle.end()) {
return it->second();
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, CastFactory, CastFactoryImpl);
} // namespace
} // namespace primitive
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_PRIMITIVE_CPU_TYPE_SEQ_H_
#define ONEFLOW_CORE_PRIMITIVE_CPU_TYPE_SEQ_H_
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/common/data_type.h"
#include <half.hpp>
#define CPU_PRIMITIVE_CHAR_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(char, DataType::kChar)
#define CPU_PRIMITIVE_INT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8)
#define CPU_PRIMITIVE_UINT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8)
#define CPU_PRIMITIVE_INT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)
#define CPU_PRIMITIVE_INT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)
#define CPU_PRIMITIVE_FLOAT_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat)
#define CPU_PRIMITIVE_DOUBLE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)
#define CPU_PRIMITIVE_FLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16)
#define CPU_PRIMITIVE_NATIVE_TYPE_SEQ \
CPU_PRIMITIVE_CHAR_TYPE_SEQ \
CPU_PRIMITIVE_INT8_TYPE_SEQ \
CPU_PRIMITIVE_UINT8_TYPE_SEQ \
CPU_PRIMITIVE_INT32_TYPE_SEQ \
CPU_PRIMITIVE_INT64_TYPE_SEQ \
CPU_PRIMITIVE_FLOAT_TYPE_SEQ \
CPU_PRIMITIVE_DOUBLE_TYPE_SEQ
#define CPU_PRIMITIVE_ALL_TYPE_SEQ \
CPU_PRIMITIVE_NATIVE_TYPE_SEQ \
CPU_PRIMITIVE_FLOAT16_TYPE_SEQ
#endif // ONEFLOW_CORE_PRIMITIVE_CPU_TYPE_SEQ_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/primitive/cast.h"
#include "oneflow/core/primitive/cuda/type_seq.h"
#include "oneflow/core/cuda/elementwise.cuh"
#include "oneflow/core/stream/cuda_stream_context.h"
namespace oneflow {
namespace primitive {
namespace {
template<typename To, typename From, typename = void>
struct CastFunctor {
__device__ To operator()(From from) const { return static_cast<To>(from); }
};
template<typename To>
struct CastFunctor<To, half, typename std::enable_if<!std::is_same<To, half>::value>::type> {
__device__ To operator()(half from) const { return static_cast<To>(static_cast<float>(from)); }
};
template<typename From>
struct CastFunctor<half, From, typename std::enable_if<!std::is_same<From, half>::value>::type> {
__device__ half operator()(From from) const {
return static_cast<half>(static_cast<float>(from));
}
};
#if CUDA_VERSION >= 11000
template<typename To>
struct CastFunctor<To, nv_bfloat16,
typename std::enable_if<!(std::is_same<To, nv_bfloat16>::value
|| std::is_same<To, half>::value)>::type> {
__device__ To operator()(nv_bfloat16 from) const {
return static_cast<To>(static_cast<float>(from));
}
};
template<typename From>
struct CastFunctor<nv_bfloat16, From,
typename std::enable_if<!(std::is_same<From, nv_bfloat16>::value
|| std::is_same<From, half>::value)>::type> {
__device__ nv_bfloat16 operator()(From from) const {
return static_cast<nv_bfloat16>(static_cast<float>(from));
}
};
#endif // CUDA_VERSION >= 11000
template<typename From, typename To>
void LaunchCast(cudaStream_t stream, const void* from, void* to, size_t count) {
OF_CUDA_CHECK((cuda::elementwise::Unary<CastFunctor<To, From>, To, From>(
CastFunctor<To, From>(), count, reinterpret_cast<To*>(to),
reinterpret_cast<const From*>(from), stream)));
}
using LaunchFn = std::function<void(cudaStream_t /*stream*/, const void* /*from*/, void* /*to*/,
size_t /*count*/)>;
class CastImpl : public Cast {
public:
OF_DISALLOW_COPY_AND_MOVE(CastImpl);
explicit CastImpl(LaunchFn launch_fn) : launch_fn_(std::move(launch_fn)) {}
~CastImpl() = default;
void Launch(StreamContext* stream_ctx, const void* from, void* to, size_t count) override {
auto* cuda_stream_ctx = CHECK_NOTNULL(dynamic_cast<CudaStreamContext*>(stream_ctx));
launch_fn_(cuda_stream_ctx->cuda_stream(), from, to, count);
}
private:
LaunchFn launch_fn_;
};
template<typename From, typename To>
std::unique_ptr<Cast> NewCast() {
return std::unique_ptr<Cast>(new CastImpl(LaunchCast<From, To>));
}
class CastFactoryImpl : public CastFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(CastFactoryImpl);
CastFactoryImpl() = default;
~CastFactoryImpl() override = default;
std::unique_ptr<Cast> New(DataType from, DataType to) override {
#define MAKE_NEW_CAST_ENTRY(from_pair, to_pair) \
{std::make_pair(OF_PP_PAIR_SECOND(from_pair), OF_PP_PAIR_SECOND(to_pair)), \
NewCast<OF_PP_PAIR_FIRST(from_pair), OF_PP_PAIR_FIRST(to_pair)>},
static const std::map<std::pair<DataType, DataType>, std::function<std::unique_ptr<Cast>()>>
new_cast_handle{OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_CAST_ENTRY, CUDA_PRIMITIVE_ALL_TYPE_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ)};
#undef MAKE_NEW_CAST_ENTRY
const auto it = new_cast_handle.find(std::make_pair(from, to));
if (it != new_cast_handle.end()) {
return it->second();
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kGPU, CastFactory, CastFactoryImpl);
} // namespace
} // namespace primitive
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_PRIMITIVE_CUDA_TYPE_SEQ_H_
#define ONEFLOW_CORE_PRIMITIVE_CUDA_TYPE_SEQ_H_
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/common/data_type.h"
#ifdef WITH_CUDA
#include <cuda.h>
#include <cuda_fp16.h>
#if CUDA_VERSION >= 11000
#include <cuda_bf16.h>
#endif // CUDA_VERSION >= 11000
#define CUDA_PRIMITIVE_CHAR_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(char, DataType::kChar)
#define CUDA_PRIMITIVE_INT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8)
#define CUDA_PRIMITIVE_UINT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8)
#define CUDA_PRIMITIVE_INT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)
#define CUDA_PRIMITIVE_INT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)
#define CUDA_PRIMITIVE_FLOAT_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat)
#define CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)
#define CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16)
#if CUDA_VERSION >= 11000
#define CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16)
#else
#define CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
#endif // CUDA_VERSION >= 11000
#define CUDA_PRIMITIVE_ALL_TYPE_SEQ \
CUDA_PRIMITIVE_CHAR_TYPE_SEQ \
CUDA_PRIMITIVE_INT8_TYPE_SEQ \
CUDA_PRIMITIVE_UINT8_TYPE_SEQ \
CUDA_PRIMITIVE_INT32_TYPE_SEQ \
CUDA_PRIMITIVE_INT64_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \
CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \
CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
#endif // WITH_CUDA
#endif // ONEFLOW_CORE_PRIMITIVE_CUDA_TYPE_SEQ_H_
......@@ -48,14 +48,14 @@ static std::unique_ptr<typename FactoryType::PrimitiveType> NewPrimitive(DeviceT
Args&&... args) {
std::unique_ptr<FactoryType> factory = NewObjUniquePtr<DeviceType, FactoryType>(device_type);
if (!factory) { return nullptr; }
return factory->New(std::forward<Args...>(args)...);
return factory->New(std::forward<Args>(args)...);
}
template<typename FactoryType, typename... Args>
static std::unique_ptr<typename FactoryType::PrimitiveType> NewPrimitive(
const std::string& device_tag, Args&&... args) {
const DeviceType device_type = CHECK_JUST(DeviceType4DeviceTag(device_tag));
return NewPrimitive<FactoryType, Args...>(device_type, std::forward<Args...>(args)...);
return NewPrimitive<FactoryType, Args...>(device_type, std::forward<Args>(args)...);
}
#define REGISTER_PRIMITIVE_FACTORY(device, Base, Derived) \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册