未验证 提交 95e1434b 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Add bfloat16 data type (#25402)

上级 3ba7b9b5
...@@ -116,6 +116,8 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) { ...@@ -116,6 +116,8 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
return platform::to_void_cast(tensor.data<unsigned char>()); return platform::to_void_cast(tensor.data<unsigned char>());
case mkldnn::memory::data_type::s32: case mkldnn::memory::data_type::s32:
return platform::to_void_cast(tensor.data<int32_t>()); return platform::to_void_cast(tensor.data<int32_t>());
case mkldnn::memory::data_type::bf16:
return platform::to_void_cast(tensor.data<paddle::platform::bfloat16>());
default: default:
PADDLE_THROW( PADDLE_THROW(
platform::errors::InvalidArgument("Wrong mkldnn type provided.")); platform::errors::InvalidArgument("Wrong mkldnn type provided."));
......
...@@ -61,7 +61,8 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) { ...@@ -61,7 +61,8 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
{DataTypeTrait<float>::DataType(), MKLDNNDataType::f32}, {DataTypeTrait<float>::DataType(), MKLDNNDataType::f32},
{DataTypeTrait<int8_t>::DataType(), MKLDNNDataType::s8}, {DataTypeTrait<int8_t>::DataType(), MKLDNNDataType::s8},
{DataTypeTrait<uint8_t>::DataType(), MKLDNNDataType::u8}, {DataTypeTrait<uint8_t>::DataType(), MKLDNNDataType::u8},
{DataTypeTrait<int32_t>::DataType(), MKLDNNDataType::s32}}; {DataTypeTrait<int32_t>::DataType(), MKLDNNDataType::s32},
{DataTypeTrait<platform::bfloat16>::DataType(), MKLDNNDataType::bf16}};
auto iter = dict.find(static_cast<int>(type)); auto iter = dict.find(static_cast<int>(type));
if (iter != dict.end()) return iter->second; if (iter != dict.end()) return iter->second;
return MKLDNNDataType::undef; return MKLDNNDataType::undef;
...@@ -74,6 +75,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, ...@@ -74,6 +75,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type, const OpKernelType& expected_kernel_type,
const Tensor& in, Tensor* out); const Tensor& in, Tensor* out);
void* GetDataFromTensor(const Tensor& tensor, MKLDNNDataType type);
#endif #endif
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to); std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);
......
...@@ -43,3 +43,17 @@ TEST(DataTransform, DataLayoutFunction) { ...@@ -43,3 +43,17 @@ TEST(DataTransform, DataLayoutFunction) {
EXPECT_TRUE(in.layout() == paddle::framework::DataLayout::kNHWC); EXPECT_TRUE(in.layout() == paddle::framework::DataLayout::kNHWC);
EXPECT_TRUE(in.dims() == paddle::framework::make_ddim({2, 3, 1, 2})); EXPECT_TRUE(in.dims() == paddle::framework::make_ddim({2, 3, 1, 2}));
} }
#ifdef PADDLE_WITH_MKLDNN
TEST(DataTransform, GetDataFromTensorDNNL) {
auto place = paddle::platform::CPUPlace();
paddle::framework::Tensor in = paddle::framework::Tensor();
in.mutable_data<paddle::platform::bfloat16>(
paddle::framework::make_ddim({2, 3, 1, 2}), place);
void* in_data =
paddle::framework::GetDataFromTensor(in, dnnl::memory::data_type::bf16);
EXPECT_EQ(in_data, paddle::platform::to_void_cast(
in.data<paddle::platform::bfloat16>()));
}
#endif
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <unordered_map> #include <unordered_map>
using float16 = paddle::platform::float16; using float16 = paddle::platform::float16;
using bfloat16 = paddle::platform::bfloat16;
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -17,6 +17,8 @@ limitations under the License. */ ...@@ -17,6 +17,8 @@ limitations under the License. */
#include <typeindex> #include <typeindex>
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
...@@ -36,15 +38,16 @@ struct DataTypeTrait<void> { ...@@ -36,15 +38,16 @@ struct DataTypeTrait<void> {
#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \ #define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
callback(cpp_type, ::paddle::framework::proto::VarType::proto_type); callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
#define _ForEachDataType_(callback) \ #define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \ _ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \ _ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, double, FP64); \ _ForEachDataTypeHelper_(callback, ::paddle::platform::bfloat16, BF16); \
_ForEachDataTypeHelper_(callback, int, INT32); \ _ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \ _ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, bool, BOOL); \ _ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \ _ForEachDataTypeHelper_(callback, bool, BOOL); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \ _ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8) _ForEachDataTypeHelper_(callback, int8_t, INT8)
#define _ForEachDataTypeSmall_(callback) \ #define _ForEachDataTypeSmall_(callback) \
......
...@@ -38,3 +38,25 @@ TEST(DataType, float16) { ...@@ -38,3 +38,25 @@ TEST(DataType, float16) {
std::string type = "::paddle::platform::float16"; std::string type = "::paddle::platform::float16";
EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str()); EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
} }
TEST(DataType, bfloat16) {
using paddle::framework::Tensor;
using paddle::platform::CPUPlace;
using paddle::platform::bfloat16;
namespace f = paddle::framework;
f::proto::VarType::Type dtype = f::proto::VarType::BF16;
Tensor tensor;
CPUPlace cpu;
tensor.mutable_data(cpu, dtype);
// test bf16 tensor
EXPECT_EQ(tensor.type(), f::ToDataType(typeid(bfloat16)));
// test bf16 size
EXPECT_EQ(f::SizeOfType(dtype), 2u);
// test debug info
std::string type = "::paddle::platform::bfloat16";
EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
}
...@@ -77,6 +77,10 @@ void TransDataType(const OpKernelType& kernel_type_for_var, ...@@ -77,6 +77,10 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
framework::VisitDataType(dst_type, framework::VisitDataType(dst_type,
CastDataType<platform::float16>(in, out, ctx)); CastDataType<platform::float16>(in, out, ctx));
break; break;
case proto::VarType::BF16:
framework::VisitDataType(dst_type,
CastDataType<platform::bfloat16>(in, out, ctx));
break;
case proto::VarType::FP32: case proto::VarType::FP32:
framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx)); framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));
break; break;
......
...@@ -24,6 +24,11 @@ TEST(DataTypeTransform, CPUTransform) { ...@@ -24,6 +24,11 @@ TEST(DataTypeTransform, CPUTransform) {
paddle::framework::DataLayout::kAnyLayout, paddle::framework::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain); paddle::framework::LibraryType::kPlain);
auto kernel_bf16 = paddle::framework::OpKernelType(
paddle::framework::proto::VarType::BF16, place,
paddle::framework::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_fp32 = paddle::framework::OpKernelType( auto kernel_fp32 = paddle::framework::OpKernelType(
paddle::framework::proto::VarType::FP32, place, paddle::framework::proto::VarType::FP32, place,
paddle::framework::DataLayout::kAnyLayout, paddle::framework::DataLayout::kAnyLayout,
...@@ -189,4 +194,120 @@ TEST(DataTypeTransform, CPUTransform) { ...@@ -189,4 +194,120 @@ TEST(DataTypeTransform, CPUTransform) {
static_cast<paddle::platform::float16>(in_data_bool[i]).x); static_cast<paddle::platform::float16>(in_data_bool[i]).x);
} }
} }
// data type transform from/to bfloat16
{
paddle::framework::Tensor in;
paddle::framework::Tensor out;
paddle::platform::bfloat16* ptr =
in.mutable_data<paddle::platform::bfloat16>(
paddle::framework::make_ddim({2, 3}), place);
int data_number = 2 * 3;
for (int i = 0; i < data_number; ++i) {
ptr[i] = i;
}
// transform from bfloat16 to other data types
paddle::framework::TransDataType(kernel_bf16, kernel_fp32, in, &out);
float* out_data_float = out.data<float>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_float[i], static_cast<float>(ptr[i]));
}
paddle::framework::TransDataType(kernel_bf16, kernel_fp64, in, &out);
double* out_data_double = out.data<double>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_double[i], static_cast<double>(ptr[i]));
}
paddle::framework::TransDataType(kernel_bf16, kernel_int32, in, &out);
int* out_data_int = out.data<int>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_int[i], static_cast<int>(ptr[i]));
}
paddle::framework::TransDataType(kernel_bf16, kernel_int64, in, &out);
int64_t* out_data_int64 = out.data<int64_t>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i]));
}
paddle::framework::TransDataType(kernel_bf16, kernel_bool, in, &out);
bool* out_data_bool = out.data<bool>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_bool[i], static_cast<bool>(ptr[i]));
}
// transform float to bfloat16
float* in_data_float =
in.mutable_data<float>(paddle::framework::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_float[i] = i;
}
paddle::framework::TransDataType(kernel_fp32, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_float[i]).x);
}
// transform double to bfloat16
double* in_data_double =
in.mutable_data<double>(paddle::framework::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_double[i] = i;
}
paddle::framework::TransDataType(kernel_fp64, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_double[i]).x);
}
// transform int to bfloat16
int* in_data_int =
in.mutable_data<int>(paddle::framework::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_int[i] = i;
}
paddle::framework::TransDataType(kernel_int32, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_int[i]).x);
}
// transform int64 to bfloat16
int64_t* in_data_int64 =
in.mutable_data<int64_t>(paddle::framework::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_int64[i] = i;
}
paddle::framework::TransDataType(kernel_int64, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_int64[i]).x);
}
// transform bool to bfloat16
bool* in_data_bool =
in.mutable_data<bool>(paddle::framework::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_bool[i] = i;
}
paddle::framework::TransDataType(kernel_bool, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_bool[i]).x);
}
}
} }
...@@ -167,6 +167,8 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num, ...@@ -167,6 +167,8 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num,
// more detail see: 180 page of // more detail see: 180 page of
// https://www.openmp.org/wp-content/uploads/OpenMP4.0.0.pdf // https://www.openmp.org/wp-content/uploads/OpenMP4.0.0.pdf
#pragma omp declare reduction(+ : paddle::platform::float16 : omp_out += omp_in) #pragma omp declare reduction(+ : paddle::platform::float16 : omp_out += omp_in)
#pragma omp declare reduction(+ : paddle::platform::bfloat16 : omp_out += \
omp_in)
#endif #endif
template <typename T> template <typename T>
......
...@@ -23,6 +23,7 @@ template <typename T> ...@@ -23,6 +23,7 @@ template <typename T>
static ::DLDataType GetDLDataTypeCode() { static ::DLDataType GetDLDataTypeCode() {
::DLDataType dtype; ::DLDataType dtype;
if (std::is_same<T, platform::float16>::value || if (std::is_same<T, platform::float16>::value ||
std::is_same<T, platform::bfloat16>::value ||
std::is_floating_point<T>::value) { std::is_floating_point<T>::value) {
dtype.code = kDLFloat; dtype.code = kDLFloat;
} else if (std::is_unsigned<T>::value) { } else if (std::is_unsigned<T>::value) {
......
...@@ -90,32 +90,6 @@ void MemoryOptimizePass::CollectLifeCycle( ...@@ -90,32 +90,6 @@ void MemoryOptimizePass::CollectLifeCycle(
} }
} }
// TODO(Superjomn) Make this a general help method.
int DataTypeToSpace(framework::proto::VarType_Type type) {
switch (type) {
case framework::proto::VarType_Type_BOOL:
return sizeof(bool);
case framework::proto::VarType_Type_FP32:
return sizeof(float);
case framework::proto::VarType_Type_INT32:
return sizeof(int32_t);
case framework::proto::VarType_Type_INT64:
return sizeof(int64_t);
case framework::proto::VarType_Type_INT16:
return sizeof(int16_t);
case framework::proto::VarType_Type_FP16:
return sizeof(int16_t);
case framework::proto::VarType_Type_FP64:
return sizeof(double);
case framework::proto::VarType_Type_UINT8:
return sizeof(unsigned char);
case framework::proto::VarType_Type_INT8:
return sizeof(int8_t);
default:
PADDLE_THROW("Unknown data type");
}
}
void MemoryOptimizePass::CollectVarMemorySize( void MemoryOptimizePass::CollectVarMemorySize(
space_table_t* space_table) const { space_table_t* space_table) const {
const int fake_batch_size = 1; const int fake_batch_size = 1;
...@@ -163,7 +137,7 @@ void MemoryOptimizePass::CollectVarMemorySize( ...@@ -163,7 +137,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
int size = std::accumulate(shape.begin(), shape.end(), 1, int size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int>()); std::multiplies<int>());
(*space_table)[node->Var()->Name()] = (*space_table)[node->Var()->Name()] =
size * DataTypeToSpace(node->Var()->GetDataType()); size * paddle::framework::SizeOfType(node->Var()->GetDataType());
} }
} }
} }
......
...@@ -14,15 +14,16 @@ ...@@ -14,15 +14,16 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/inference/lite/engine.h"
#include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/operators/lite/ut_helper.h"
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/lite/engine.h"
#include "paddle/fluid/operators/lite/ut_helper.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace lite { namespace lite {
......
...@@ -65,13 +65,14 @@ class SplitFunctor { ...@@ -65,13 +65,14 @@ class SplitFunctor {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#define FOR_ALL_TYPES(macro) \ #define FOR_ALL_TYPES(macro) \
macro(int); \ macro(int); \
macro(float); \ macro(float); \
macro(double); \ macro(double); \
macro(bool); \ macro(bool); \
macro(int64_t); \ macro(int64_t); \
macro(int16_t); \ macro(int16_t); \
macro(uint8_t); \ macro(uint8_t); \
macro(int8_t); \ macro(int8_t); \
macro(::paddle::platform::float16) macro(::paddle::platform::float16); \
macro(::paddle::platform::bfloat16)
...@@ -34,6 +34,7 @@ namespace math { ...@@ -34,6 +34,7 @@ namespace math {
using float16 = paddle::platform::float16; using float16 = paddle::platform::float16;
template struct SetConstant<platform::CPUDeviceContext, platform::float16>; template struct SetConstant<platform::CPUDeviceContext, platform::float16>;
template struct SetConstant<platform::CPUDeviceContext, platform::bfloat16>;
template struct SetConstant<platform::CPUDeviceContext, float>; template struct SetConstant<platform::CPUDeviceContext, float>;
template struct SetConstant<platform::CPUDeviceContext, double>; template struct SetConstant<platform::CPUDeviceContext, double>;
template struct SetConstant<platform::CPUDeviceContext, int>; template struct SetConstant<platform::CPUDeviceContext, int>;
...@@ -41,16 +42,18 @@ template struct SetConstant<platform::CPUDeviceContext, int64_t>; ...@@ -41,16 +42,18 @@ template struct SetConstant<platform::CPUDeviceContext, int64_t>;
template struct SetConstant<platform::CPUDeviceContext, bool>; template struct SetConstant<platform::CPUDeviceContext, bool>;
template struct SetConstant<platform::CPUDeviceContext, uint8_t>; template struct SetConstant<platform::CPUDeviceContext, uint8_t>;
#define DEFINE_CPU_TRANS(RANK) \ #define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUDeviceContext, platform::float16, \ template struct Transpose<platform::CPUDeviceContext, platform::float16, \
RANK>; \ RANK>; \
template struct Transpose<platform::CPUDeviceContext, float, RANK>; \ template struct Transpose<platform::CPUDeviceContext, platform::bfloat16, \
template struct Transpose<platform::CPUDeviceContext, double, RANK>; \ RANK>; \
template struct Transpose<platform::CPUDeviceContext, int, RANK>; \ template struct Transpose<platform::CPUDeviceContext, float, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \ template struct Transpose<platform::CPUDeviceContext, double, RANK>; \
template struct Transpose<platform::CPUDeviceContext, bool, RANK>; \ template struct Transpose<platform::CPUDeviceContext, int, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>; \ template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>; \ template struct Transpose<platform::CPUDeviceContext, bool, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int8_t, RANK>; template struct Transpose<platform::CPUDeviceContext, int8_t, RANK>;
DEFINE_CPU_TRANS(1); DEFINE_CPU_TRANS(1);
......
...@@ -136,6 +136,8 @@ cc_test(profiler_test SRCS profiler_test.cc DEPS profiler) ...@@ -136,6 +136,8 @@ cc_test(profiler_test SRCS profiler_test.cc DEPS profiler)
nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor) nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor)
cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor) cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor)
cc_test(bfloat16_test SRCS bfloat16_test.cc DEPS lod_tensor)
nv_test(test_limit_gpu_memory SRCS test_limit_gpu_memory.cu DEPS gpu_info flags) nv_test(test_limit_gpu_memory SRCS test_limit_gpu_memory.cu DEPS gpu_info flags)
nv_library(cuda_device_guard SRCS cuda_device_guard.cc DEPS gpu_info) nv_library(cuda_device_guard SRCS cuda_device_guard.cc DEPS gpu_info)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <limits>
#if !defined(_WIN32)
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
#else
#define PADDLE_ALIGN(x) __declspec(align(x))
#endif
#include <cstring>
#include "paddle/fluid/platform/hostdevice.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace platform {
struct PADDLE_ALIGN(2) bfloat16 {
public:
uint16_t x;
bfloat16() = default;
bfloat16(const bfloat16& o) = default;
bfloat16& operator=(const bfloat16& o) = default;
bfloat16(bfloat16&& o) = default;
bfloat16& operator=(bfloat16&& o) = default;
~bfloat16() = default;
HOSTDEVICE inline explicit bfloat16(float val) {
std::memcpy(&x, reinterpret_cast<char*>(&val) + 2, 2);
}
template <class T>
HOSTDEVICE inline explicit bfloat16(const T& val)
: x(bfloat16(static_cast<float>(val)).x) {}
HOSTDEVICE inline bfloat16& operator=(bool b) {
x = b ? 0x3f80 : 0;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(int8_t val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(uint8_t val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(int16_t val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(uint16_t val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(int32_t val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(uint32_t val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(int64_t val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(uint64_t val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(float val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(double val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline explicit operator float() const {
float val = 0.f;
uint16_t temp = x;
memcpy(reinterpret_cast<char*>(&val) + 2, reinterpret_cast<char*>(&temp),
2);
return val;
}
HOSTDEVICE inline explicit operator bool() const { return (x & 0x7fff) != 0; }
HOSTDEVICE inline explicit operator int8_t() const {
return static_cast<int8_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator uint8_t() const {
return static_cast<uint8_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator int16_t() const {
return static_cast<int16_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator uint16_t() const {
return static_cast<uint16_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator int32_t() const {
return static_cast<int32_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator uint32_t() const {
return static_cast<uint32_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator int64_t() const {
return static_cast<int64_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator uint64_t() const {
return static_cast<uint64_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator double() const {
return static_cast<double>(static_cast<float>(*this));
}
};
HOSTDEVICE inline bfloat16 operator+(const bfloat16& a, const bfloat16& b) {
return bfloat16(static_cast<float>(a) + static_cast<float>(b));
}
HOSTDEVICE inline bfloat16 operator-(const bfloat16& a, const bfloat16& b) {
return bfloat16(static_cast<float>(a) - static_cast<float>(b));
}
HOSTDEVICE inline bfloat16 operator*(const bfloat16& a, const bfloat16& b) {
return bfloat16(static_cast<float>(a) * static_cast<float>(b));
}
HOSTDEVICE inline bfloat16 operator/(const bfloat16& a, const bfloat16& b) {
return bfloat16(static_cast<float>(a) / static_cast<float>(b));
}
HOSTDEVICE inline bfloat16 operator-(const bfloat16& a) {
bfloat16 res;
res.x = a.x ^ 0x8000;
return res;
}
HOSTDEVICE inline bfloat16& operator+=(bfloat16& a, // NOLINT
const bfloat16& b) {
a = bfloat16(static_cast<float>(a) + static_cast<float>(b));
return a;
}
HOSTDEVICE inline bfloat16& operator-=(bfloat16& a, // NOLINT
const bfloat16& b) {
a = bfloat16(static_cast<float>(a) - static_cast<float>(b));
return a;
}
HOSTDEVICE inline bfloat16& operator*=(bfloat16& a, // NOLINT
const bfloat16& b) {
a = bfloat16(static_cast<float>(a) * static_cast<float>(b));
return a;
}
HOSTDEVICE inline bfloat16& operator/=(bfloat16& a, // NOLINT
const bfloat16& b) {
a = bfloat16(static_cast<float>(a) / static_cast<float>(b));
return a;
}
HOSTDEVICE inline bfloat16 raw_uint16_to_bfloat16(uint16_t a) {
bfloat16 res;
res.x = a;
return res;
}
HOSTDEVICE inline bool operator==(const bfloat16& a, const bfloat16& b) {
return static_cast<float>(a) == static_cast<float>(b);
}
HOSTDEVICE inline bool operator!=(const bfloat16& a, const bfloat16& b) {
return static_cast<float>(a) != static_cast<float>(b);
}
HOSTDEVICE inline bool operator<(const bfloat16& a, const bfloat16& b) {
return static_cast<float>(a) < static_cast<float>(b);
}
HOSTDEVICE inline bool operator<=(const bfloat16& a, const bfloat16& b) {
return static_cast<float>(a) <= static_cast<float>(b);
}
HOSTDEVICE inline bool operator>(const bfloat16& a, const bfloat16& b) {
return static_cast<float>(a) > static_cast<float>(b);
}
HOSTDEVICE inline bool operator>=(const bfloat16& a, const bfloat16& b) {
return static_cast<float>(a) >= static_cast<float>(b);
}
HOSTDEVICE inline bool(isnan)(const bfloat16& a) {
return (a.x & 0x7FFF) > 0x7F80;
}
HOSTDEVICE inline bool(isinf)(const bfloat16& a) {
return (a.x & 0x7F80) == 0x7F80;
}
HOSTDEVICE inline bool(isfinite)(const bfloat16& a) {
return !((isnan)(a)) && !((isinf)(a));
}
inline std::ostream& operator<<(std::ostream& os, const bfloat16& a) {
os << a.x;
return os;
}
} // namespace platform
} // namespace paddle
namespace std {
template <>
struct is_pod<paddle::platform::bfloat16> {
static const bool value =
is_trivial<paddle::platform::bfloat16>::value &&
is_standard_layout<paddle::platform::bfloat16>::value;
};
template <>
struct is_floating_point<paddle::platform::bfloat16>
: std::integral_constant<
bool, std::is_same<paddle::platform::bfloat16,
typename std::remove_cv<
paddle::platform::bfloat16>::type>::value> {};
template <>
struct is_signed<paddle::platform::bfloat16> {
static const bool value = true;
};
template <>
struct is_unsigned<paddle::platform::bfloat16> {
static const bool value = false;
};
inline bool isnan(const paddle::platform::bfloat16& a) {
return paddle::platform::isnan(a);
}
inline bool isinf(const paddle::platform::bfloat16& a) {
return paddle::platform::isinf(a);
}
template <>
struct numeric_limits<paddle::platform::bfloat16> {
static const bool is_specialized = true;
static const bool is_signed = true;
static const bool is_integer = false;
static const bool is_exact = false;
static const bool has_infinity = true;
static const bool has_quiet_NaN = true;
static const bool has_signaling_NaN = true;
static const float_denorm_style has_denorm = denorm_present;
static const bool has_denorm_loss = false;
static const std::float_round_style round_style = std::round_to_nearest;
static const bool is_iec559 = false;
static const bool is_bounded = false;
static const bool is_modulo = false;
static const int digits = 8;
static const int digits10 = 2;
static const int max_digits10 = 9;
static const int radix = 2;
static const int min_exponent = -125;
static const int min_exponent10 = -37;
static const int max_exponent = 128;
static const int max_exponent10 = 38;
static const bool traps = true;
static const bool tinyness_before = false;
static paddle::platform::bfloat16(min)() {
return paddle::platform::raw_uint16_to_bfloat16(0x007f);
}
static paddle::platform::bfloat16 lowest() {
return paddle::platform::raw_uint16_to_bfloat16(0xff7f);
}
static paddle::platform::bfloat16(max)() {
return paddle::platform::raw_uint16_to_bfloat16(0x7f7f);
}
static paddle::platform::bfloat16 epsilon() {
return paddle::platform::raw_uint16_to_bfloat16(0x3400);
}
static paddle::platform::bfloat16 round_error() {
return paddle::platform::bfloat16(0.5);
}
static paddle::platform::bfloat16 infinity() {
return paddle::platform::raw_uint16_to_bfloat16(0x7f80);
}
static paddle::platform::bfloat16 quiet_NaN() {
return paddle::platform::raw_uint16_to_bfloat16(0xffc1);
}
static paddle::platform::bfloat16 signaling_NaN() {
return paddle::platform::raw_uint16_to_bfloat16(0xff81);
}
static paddle::platform::bfloat16 denorm_min() {
return paddle::platform::raw_uint16_to_bfloat16(0x0001);
}
};
} // namespace std
namespace Eigen {
using bfloat16 = paddle::platform::bfloat16;
template <>
struct NumTraits<bfloat16> : GenericNumTraits<bfloat16> {
enum {
IsSigned = true,
IsInteger = false,
IsComplex = false,
RequireInitialization = false
};
HOSTDEVICE static inline bfloat16 epsilon() {
return paddle::platform::raw_uint16_to_bfloat16(0x3400);
}
HOSTDEVICE static inline bfloat16 dummy_precision() {
return bfloat16(1e-5f);
}
HOSTDEVICE static inline bfloat16 highest() {
return paddle::platform::raw_uint16_to_bfloat16(0x7f7f);
}
HOSTDEVICE static inline bfloat16 lowest() {
return paddle::platform::raw_uint16_to_bfloat16(0xff7f);
}
HOSTDEVICE static inline bfloat16 infinity() {
return paddle::platform::raw_uint16_to_bfloat16(0x7f80);
}
HOSTDEVICE static inline bfloat16 quiet_NaN() {
return paddle::platform::raw_uint16_to_bfloat16(0xffc1);
}
};
namespace numext {
template <>
HOSTDEVICE inline bool(isnan)(const bfloat16& a) {
return (paddle::platform::isnan)(a);
}
template <>
HOSTDEVICE inline bool(isinf)(const bfloat16& a) {
return (paddle::platform::isinf)(a);
}
template <>
HOSTDEVICE inline bool(isfinite)(const bfloat16& a) {
return (paddle::platform::isfinite)(a);
}
template <>
HOSTDEVICE inline bfloat16 exp(const bfloat16& a) {
return bfloat16(::expf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline bfloat16 erf(const bfloat16& a) {
return bfloat16(::erff(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline bfloat16 log(const bfloat16& a) {
return bfloat16(::logf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline bfloat16 tanh(const bfloat16& a) {
return bfloat16(::tanhf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline bfloat16 sqrt(const bfloat16& a) {
return bfloat16(::sqrtf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline bfloat16 ceil(const bfloat16& a) {
return bfloat16(::ceilf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline bfloat16 floor(const bfloat16& a) {
return bfloat16(::floorf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline bfloat16 round(const bfloat16& a) {
return bfloat16(::roundf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline bfloat16 pow(const bfloat16& a, const bfloat16& b) {
return bfloat16(::powf(static_cast<float>(a), static_cast<float>(b)));
}
template <>
HOSTDEVICE inline bfloat16 abs(const bfloat16& a) {
return bfloat16(::fabs(static_cast<float>(a)));
}
} // namespace numext
} // namespace Eigen
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/bfloat16.h"
#include <vector>
#define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/init.h"
namespace paddle {
namespace platform {
using bfloat16 = paddle::platform::bfloat16;
TEST(bfloat16, conversion_cpu) {
// Conversion from float
EXPECT_EQ(bfloat16(1.0f).x, 0x3f80);
EXPECT_EQ(bfloat16(0.5f).x, 0x3f00);
EXPECT_EQ(bfloat16(0.33333f).x, 0x3eaa);
EXPECT_EQ(bfloat16(0.0f).x, 0x0000);
EXPECT_EQ(bfloat16(-0.0f).x, 0x8000);
EXPECT_EQ(bfloat16(65504.0f).x, 0x477f);
EXPECT_EQ(bfloat16(65536.0f).x, 0x4780);
// Conversion from double
EXPECT_EQ(bfloat16(1.0).x, 0x3f80);
EXPECT_EQ(bfloat16(0.5).x, 0x3f00);
EXPECT_EQ(bfloat16(0.33333).x, 0x3eaa);
EXPECT_EQ(bfloat16(0.0).x, 0x0000);
EXPECT_EQ(bfloat16(-0.0).x, 0x8000);
EXPECT_EQ(bfloat16(65504.0).x, 0x477f);
EXPECT_EQ(bfloat16(65536.0).x, 0x4780);
// Conversion from int
EXPECT_EQ(bfloat16(-1).x, 0xbf80);
EXPECT_EQ(bfloat16(0).x, 0x0000);
EXPECT_EQ(bfloat16(1).x, 0x3f80);
EXPECT_EQ(bfloat16(2).x, 0x4000);
EXPECT_EQ(bfloat16(3).x, 0x4040);
// Conversion from bool
EXPECT_EQ(bfloat16(true).x, 0x3f80);
EXPECT_EQ(bfloat16(false).x, 0x0000);
// Assignment operator
bfloat16 v_assign;
v_assign = bfloat16(0.f);
EXPECT_EQ(v_assign.x, 0x0000);
v_assign = 0.5f;
EXPECT_EQ(v_assign.x, 0x3f00);
v_assign = 0.33333;
EXPECT_EQ(v_assign.x, 0x3eaa);
v_assign = -1;
EXPECT_EQ(v_assign.x, 0xbf80);
// Conversion operator
EXPECT_EQ(static_cast<float>(bfloat16(0.5f)), 0.5f);
EXPECT_NEAR(static_cast<double>(bfloat16(0.33333)), 0.33333, 0.01);
EXPECT_EQ(static_cast<int>(bfloat16(-1)), -1);
EXPECT_EQ(static_cast<bool>(bfloat16(true)), true);
}
TEST(bfloat16, arithmetic_cpu) {
EXPECT_NEAR(static_cast<float>(bfloat16(1) + bfloat16(1)), 2, 0.001);
EXPECT_EQ(static_cast<float>(bfloat16(5) + bfloat16(-5)), 0);
EXPECT_NEAR(static_cast<float>(bfloat16(0.33333f) + bfloat16(0.66667f)), 1.0f,
0.01);
EXPECT_EQ(static_cast<float>(bfloat16(3) - bfloat16(5)), -2);
EXPECT_NEAR(static_cast<float>(bfloat16(0.66667f) - bfloat16(0.33333f)),
0.33334f, 0.01);
EXPECT_NEAR(static_cast<float>(bfloat16(3.3f) * bfloat16(2.0f)), 6.6f, 0.01);
EXPECT_NEAR(static_cast<float>(bfloat16(-2.1f) * bfloat16(-3.0f)), 6.3f, 0.1);
EXPECT_NEAR(static_cast<float>(bfloat16(2.0f) / bfloat16(3.0f)), 0.66667f,
0.01);
EXPECT_EQ(static_cast<float>(bfloat16(1.0f) / bfloat16(2.0f)), 0.5f);
EXPECT_EQ(static_cast<float>(-bfloat16(512.0f)), -512.0f);
EXPECT_EQ(static_cast<float>(-bfloat16(-512.0f)), 512.0f);
}
TEST(bfloat16, comparison_cpu) {
EXPECT_TRUE(bfloat16(1.0f) == bfloat16(1.0f));
EXPECT_FALSE(bfloat16(-1.0f) == bfloat16(-0.5f));
EXPECT_TRUE(bfloat16(1.0f) != bfloat16(0.5f));
EXPECT_FALSE(bfloat16(-1.0f) != bfloat16(-1.0f));
EXPECT_TRUE(bfloat16(1.0f) < bfloat16(2.0f));
EXPECT_FALSE(bfloat16(-1.0f) < bfloat16(-1.0f));
EXPECT_TRUE(bfloat16(1.0f) <= bfloat16(1.0f));
EXPECT_TRUE(bfloat16(2.0f) > bfloat16(1.0f));
EXPECT_FALSE(bfloat16(-2.0f) > bfloat16(-2.0f));
EXPECT_TRUE(bfloat16(2.0f) >= bfloat16(2.0f));
}
TEST(bfloat16, lod_tensor_cpu) {
framework::LoDTensor lod_tensor;
std::vector<bfloat16> input_data = {bfloat16(1.0f), bfloat16(0.5f),
bfloat16(0.33333f), bfloat16(0.0f)};
EXPECT_EQ(input_data[0].x, 0x3f80);
EXPECT_EQ(input_data[1].x, 0x3f00);
EXPECT_EQ(input_data[2].x, 0x3eaa);
EXPECT_EQ(input_data[3].x, 0x0000);
lod_tensor.Resize({4, 1});
lod_tensor.set_lod(framework::LoD({{0, 2, 4}}));
bfloat16* data_ptr = lod_tensor.mutable_data<bfloat16>(CPUPlace());
EXPECT_NE(data_ptr, nullptr);
EXPECT_EQ(input_data.size(), static_cast<size_t>(lod_tensor.numel()));
for (size_t i = 0; i < input_data.size(); ++i) {
data_ptr[i] = input_data[i];
EXPECT_EQ(data_ptr[i].x, input_data[i].x);
}
}
TEST(bfloat16, floating) {
// compile time assert.
PADDLE_ENFORCE_EQ(
std::is_floating_point<bfloat16>::value, true,
platform::errors::Fatal("std::is_floating_point with bfloat16 data type "
"should be equal to true but it is not"));
}
TEST(bfloat16, print) {
bfloat16 a = bfloat16(1.0f);
std::cout << a << std::endl;
}
// CPU test
TEST(bfloat16, isinf) {
bfloat16 a;
a.x = 0x7f80;
bfloat16 b = bfloat16(INFINITY);
bfloat16 c = static_cast<bfloat16>(INFINITY);
EXPECT_EQ(std::isinf(a), true);
EXPECT_EQ(std::isinf(b), true);
EXPECT_EQ(std::isinf(c), true);
}
TEST(bfloat16, isnan) {
bfloat16 a;
a.x = 0x7fff;
bfloat16 b = bfloat16(NAN);
bfloat16 c = static_cast<bfloat16>(NAN);
EXPECT_EQ(std::isnan(a), true);
EXPECT_EQ(std::isnan(b), true);
EXPECT_EQ(std::isnan(c), true);
}
} // namespace platform
} // namespace paddle
...@@ -161,6 +161,12 @@ inline mkldnn::memory::data_type MKLDNNGetDataType<uint8_t>() { ...@@ -161,6 +161,12 @@ inline mkldnn::memory::data_type MKLDNNGetDataType<uint8_t>() {
return mkldnn::memory::data_type::u8; return mkldnn::memory::data_type::u8;
} }
template <>
inline mkldnn::memory::data_type
MKLDNNGetDataType<paddle::platform::bfloat16>() {
return mkldnn::memory::data_type::bf16;
}
inline void Reorder(mkldnn::memory src, mkldnn::memory dst, inline void Reorder(mkldnn::memory src, mkldnn::memory dst,
const mkldnn::engine& engine) { const mkldnn::engine& engine) {
auto reorder_prim = mkldnn::reorder(src, dst); auto reorder_prim = mkldnn::reorder(src, dst);
......
...@@ -26,6 +26,7 @@ limitations under the License. */ ...@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "pybind11/numpy.h" #include "pybind11/numpy.h"
...@@ -104,6 +105,7 @@ struct ValidDTypeToPyArrayChecker { ...@@ -104,6 +105,7 @@ struct ValidDTypeToPyArrayChecker {
} }
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16); DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::bfloat16);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(float); DECLARE_VALID_DTYPE_TO_PY_ARRAY(float);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(double); DECLARE_VALID_DTYPE_TO_PY_ARRAY(double);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(bool); DECLARE_VALID_DTYPE_TO_PY_ARRAY(bool);
...@@ -119,6 +121,9 @@ inline std::string TensorDTypeToPyDTypeStr( ...@@ -119,6 +121,9 @@ inline std::string TensorDTypeToPyDTypeStr(
if (type == proto_type) { \ if (type == proto_type) { \
if (std::is_same<T, platform::float16>::value) { \ if (std::is_same<T, platform::float16>::value) { \
return "e"; \ return "e"; \
} else if (std::is_same<T, platform::bfloat16>::value) { \
/* NumPy character code of uint16 due to no support for bfloat16 */ \
return "H"; \
} else { \ } else { \
constexpr auto kIsValidDType = ValidDTypeToPyArrayChecker<T>::kValue; \ constexpr auto kIsValidDType = ValidDTypeToPyArrayChecker<T>::kValue; \
PADDLE_ENFORCE_EQ( \ PADDLE_ENFORCE_EQ( \
...@@ -262,10 +267,10 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj, ...@@ -262,10 +267,10 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj,
SetTensorFromPyArrayT<paddle::platform::float16, P>(self, array, place, SetTensorFromPyArrayT<paddle::platform::float16, P>(self, array, place,
zero_copy); zero_copy);
} else if (py::isinstance<py::array_t<uint16_t>>(array)) { } else if (py::isinstance<py::array_t<uint16_t>>(array)) {
// TODO(cql): temporary keeping uint16, which is used for casting float16 // since there is still no support for bfloat16 in NumPy,
// before. It should be depracated later. // uint16 is used for casting bfloat16
SetTensorFromPyArrayT<paddle::platform::float16, P>(self, array, place, SetTensorFromPyArrayT<paddle::platform::bfloat16, P>(self, array, place,
zero_copy); zero_copy);
} else if (py::isinstance<py::array_t<bool>>(array)) { } else if (py::isinstance<py::array_t<bool>>(array)) {
SetTensorFromPyArrayT<bool, P>(self, array, place, zero_copy); SetTensorFromPyArrayT<bool, P>(self, array, place, zero_copy);
} else { } else {
...@@ -479,6 +484,8 @@ inline framework::Tensor *_sliceTensor(const framework::Tensor &self, ...@@ -479,6 +484,8 @@ inline framework::Tensor *_sliceTensor(const framework::Tensor &self,
switch (src_type) { switch (src_type) {
case framework::proto::VarType::FP16: case framework::proto::VarType::FP16:
return _sliceAndConcat<paddle::platform::float16>(self, obj, dim); return _sliceAndConcat<paddle::platform::float16>(self, obj, dim);
case framework::proto::VarType::BF16:
return _sliceAndConcat<paddle::platform::bfloat16>(self, obj, dim);
case framework::proto::VarType::FP32: case framework::proto::VarType::FP32:
return _sliceAndConcat<float>(self, obj, dim); return _sliceAndConcat<float>(self, obj, dim);
case framework::proto::VarType::FP64: case framework::proto::VarType::FP64:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册