Commit 2dee7f2d authored by: S seiriosPlus

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into optimize/large_scale_kv_spped

......@@ -16,6 +16,7 @@ else()
set(paddle_known_gpu_archs8 "30 35 50 52 60 61")
set(paddle_known_gpu_archs9 "30 35 50 52 60 61 70")
set(paddle_known_gpu_archs10 "30 35 50 52 60 61 70 75")
set(paddle_known_gpu_archs11 "35 50 52 60 61 70 75 80")
endif()
######################################################################################
......@@ -188,6 +189,10 @@ elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) # CUDA 10.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs10})
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_MWAITXINTRIN_H_INCLUDED")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__STRICT_ANSI__")
elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0) # CUDA 11.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs11})
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_MWAITXINTRIN_H_INCLUDED")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__STRICT_ANSI__")
endif()
add_definitions("-DPADDLE_CUDA_BINVER=\"${CUDA_VERSION_MAJOR}${CUDA_VERSION_MINOR}\"")
......
......@@ -272,7 +272,7 @@ cc_test(op_compatible_info_test SRCS op_compatible_info_test.cc DEPS op_compatib
cc_library(save_load_util SRCS save_load_util DEPS tensor scope layer)
cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer)
cc_library(generator SRCS generator.cc)
cc_library(generator SRCS generator.cc DEPS enforce place)
# Get the current working branch
execute_process(
......
......@@ -116,6 +116,8 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
return platform::to_void_cast(tensor.data<unsigned char>());
case mkldnn::memory::data_type::s32:
return platform::to_void_cast(tensor.data<int32_t>());
case mkldnn::memory::data_type::bf16:
return platform::to_void_cast(tensor.data<paddle::platform::bfloat16>());
default:
PADDLE_THROW(
platform::errors::InvalidArgument("Wrong mkldnn type provided."));
......
......@@ -61,7 +61,8 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
{DataTypeTrait<float>::DataType(), MKLDNNDataType::f32},
{DataTypeTrait<int8_t>::DataType(), MKLDNNDataType::s8},
{DataTypeTrait<uint8_t>::DataType(), MKLDNNDataType::u8},
{DataTypeTrait<int32_t>::DataType(), MKLDNNDataType::s32}};
{DataTypeTrait<int32_t>::DataType(), MKLDNNDataType::s32},
{DataTypeTrait<platform::bfloat16>::DataType(), MKLDNNDataType::bf16}};
auto iter = dict.find(static_cast<int>(type));
if (iter != dict.end()) return iter->second;
return MKLDNNDataType::undef;
......@@ -74,6 +75,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type,
const Tensor& in, Tensor* out);
void* GetDataFromTensor(const Tensor& tensor, MKLDNNDataType type);
#endif
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);
......
......@@ -43,3 +43,17 @@ TEST(DataTransform, DataLayoutFunction) {
EXPECT_TRUE(in.layout() == paddle::framework::DataLayout::kNHWC);
EXPECT_TRUE(in.dims() == paddle::framework::make_ddim({2, 3, 1, 2}));
}
#ifdef PADDLE_WITH_MKLDNN
TEST(DataTransform, GetDataFromTensorDNNL) {
auto place = paddle::platform::CPUPlace();
paddle::framework::Tensor in = paddle::framework::Tensor();
in.mutable_data<paddle::platform::bfloat16>(
paddle::framework::make_ddim({2, 3, 1, 2}), place);
void* in_data =
paddle::framework::GetDataFromTensor(in, dnnl::memory::data_type::bf16);
EXPECT_EQ(in_data, paddle::platform::to_void_cast(
in.data<paddle::platform::bfloat16>()));
}
#endif
......@@ -18,6 +18,7 @@
#include <unordered_map>
using float16 = paddle::platform::float16;
using bfloat16 = paddle::platform::bfloat16;
namespace paddle {
namespace framework {
......
......@@ -17,6 +17,8 @@ limitations under the License. */
#include <typeindex>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
......@@ -36,15 +38,16 @@ struct DataTypeTrait<void> {
#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
#define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, bool, BOOL); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
#define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::bfloat16, BF16); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, bool, BOOL); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8)
#define _ForEachDataTypeSmall_(callback) \
......
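
Note on the macro change above: _ForEachDataType_ is the single registry that VisitDataType expands over, so adding the BF16 entry is what lets every visitor-based path in this commit (e.g. the CastDataType branch in TransDataType below) handle bfloat16 without further changes. A minimal, self-contained sketch of this X-macro dispatch pattern; all names here are illustrative stand-ins, not Paddle's API:

#include <cstdint>
#include <iostream>
#include <stdexcept>

enum class ProtoType { FP32, BF16, FP64, INT32 };

struct bfloat16 { uint16_t x; };  // stand-in for paddle::platform::bfloat16

// One registry of (cpp_type, enum value) pairs, mirroring _ForEachDataType_.
#define FOR_EACH_DATA_TYPE(callback)   \
  callback(float, ProtoType::FP32);    \
  callback(bfloat16, ProtoType::BF16); \
  callback(double, ProtoType::FP64);   \
  callback(int32_t, ProtoType::INT32)

template <typename Visitor>
void VisitDataType(ProtoType type, Visitor visitor) {
// Each registry entry becomes an if-branch that calls the visitor with the
// matching C++ type, so new dtypes only need a new registry line.
#define VISIT_CASE(cpp_type, proto_type) \
  if (type == proto_type) {              \
    visitor.template apply<cpp_type>();  \
    return;                              \
  }
  FOR_EACH_DATA_TYPE(VISIT_CASE);
#undef VISIT_CASE
  throw std::runtime_error("unsupported data type");
}

struct PrintSize {
  template <typename T>
  void apply() const {
    std::cout << "sizeof = " << sizeof(T) << "\n";
  }
};

int main() {
  VisitDataType(ProtoType::BF16, PrintSize{});  // prints "sizeof = 2"
}
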
......@@ -38,3 +38,25 @@ TEST(DataType, float16) {
std::string type = "::paddle::platform::float16";
EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
}
TEST(DataType, bfloat16) {
using paddle::framework::Tensor;
using paddle::platform::CPUPlace;
using paddle::platform::bfloat16;
namespace f = paddle::framework;
f::proto::VarType::Type dtype = f::proto::VarType::BF16;
Tensor tensor;
CPUPlace cpu;
tensor.mutable_data(cpu, dtype);
// test bf16 tensor
EXPECT_EQ(tensor.type(), f::ToDataType(typeid(bfloat16)));
// test bf16 size
EXPECT_EQ(f::SizeOfType(dtype), 2u);
// test debug info
std::string type = "::paddle::platform::bfloat16";
EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
}
......@@ -77,6 +77,10 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
framework::VisitDataType(dst_type,
CastDataType<platform::float16>(in, out, ctx));
break;
case proto::VarType::BF16:
framework::VisitDataType(dst_type,
CastDataType<platform::bfloat16>(in, out, ctx));
break;
case proto::VarType::FP32:
framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));
break;
......
......@@ -24,6 +24,11 @@ TEST(DataTypeTransform, CPUTransform) {
paddle::framework::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_bf16 = paddle::framework::OpKernelType(
paddle::framework::proto::VarType::BF16, place,
paddle::framework::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_fp32 = paddle::framework::OpKernelType(
paddle::framework::proto::VarType::FP32, place,
paddle::framework::DataLayout::kAnyLayout,
......@@ -189,4 +194,120 @@ TEST(DataTypeTransform, CPUTransform) {
static_cast<paddle::platform::float16>(in_data_bool[i]).x);
}
}
// data type transform from/to bfloat16
{
paddle::framework::Tensor in;
paddle::framework::Tensor out;
paddle::platform::bfloat16* ptr =
in.mutable_data<paddle::platform::bfloat16>(
paddle::framework::make_ddim({2, 3}), place);
int data_number = 2 * 3;
for (int i = 0; i < data_number; ++i) {
ptr[i] = i;
}
// transform from bfloat16 to other data types
paddle::framework::TransDataType(kernel_bf16, kernel_fp32, in, &out);
float* out_data_float = out.data<float>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_float[i], static_cast<float>(ptr[i]));
}
paddle::framework::TransDataType(kernel_bf16, kernel_fp64, in, &out);
double* out_data_double = out.data<double>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_double[i], static_cast<double>(ptr[i]));
}
paddle::framework::TransDataType(kernel_bf16, kernel_int32, in, &out);
int* out_data_int = out.data<int>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_int[i], static_cast<int>(ptr[i]));
}
paddle::framework::TransDataType(kernel_bf16, kernel_int64, in, &out);
int64_t* out_data_int64 = out.data<int64_t>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i]));
}
paddle::framework::TransDataType(kernel_bf16, kernel_bool, in, &out);
bool* out_data_bool = out.data<bool>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_bool[i], static_cast<bool>(ptr[i]));
}
// transform float to bfloat16
float* in_data_float =
in.mutable_data<float>(paddle::framework::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_float[i] = i;
}
paddle::framework::TransDataType(kernel_fp32, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_float[i]).x);
}
// transform double to bfloat16
double* in_data_double =
in.mutable_data<double>(paddle::framework::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_double[i] = i;
}
paddle::framework::TransDataType(kernel_fp64, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_double[i]).x);
}
// transform int to bfloat16
int* in_data_int =
in.mutable_data<int>(paddle::framework::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_int[i] = i;
}
paddle::framework::TransDataType(kernel_int32, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_int[i]).x);
}
// transform int64 to bfloat16
int64_t* in_data_int64 =
in.mutable_data<int64_t>(paddle::framework::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_int64[i] = i;
}
paddle::framework::TransDataType(kernel_int64, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_int64[i]).x);
}
// transform bool to bfloat16
bool* in_data_bool =
in.mutable_data<bool>(paddle::framework::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_bool[i] = i;
}
paddle::framework::TransDataType(kernel_bool, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_bool[i]).x);
}
}
}
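
The .x comparisons in the test above work because bfloat16 stores the upper 16 bits of an IEEE-754 float32 (1 sign bit, 8 exponent bits, 7 mantissa bits), which is also why SizeOfType reports 2 bytes for BF16. A small stand-alone sketch of that truncation, assuming round-toward-zero for simplicity (Paddle's conversion may round differently):

#include <cstdint>
#include <cstdio>
#include <cstring>

// bfloat16 keeps the top half of a float32's bit pattern; this is what the
// .x bit-pattern comparisons in the transform test rely on.
uint16_t FloatToBF16Bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);  // truncating sketch
}

int main() {
  std::printf("0x%04x\n", FloatToBF16Bits(1.0f));  // 0x3f80
  std::printf("0x%04x\n", FloatToBF16Bits(2.0f));  // 0x4000
}
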
......@@ -167,6 +167,8 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num,
// for more detail, see page 180 of
// https://www.openmp.org/wp-content/uploads/OpenMP4.0.0.pdf
#pragma omp declare reduction(+ : paddle::platform::float16 : omp_out += omp_in)
#pragma omp declare reduction(+ : paddle::platform::bfloat16 : omp_out += \
omp_in)
#endif
template <typename T>
......@@ -205,6 +207,21 @@ void CheckNanInf<paddle::platform::float16>(
PrintNanInf(value, numel, print_num, op_type, var_name);
}
}
template <>
void CheckNanInf<paddle::platform::bfloat16>(
const paddle::platform::bfloat16* value, const size_t numel, int print_num,
const std::string& op_type, const std::string& var_name) {
float sum = 0.0f;
#pragma omp parallel for reduction(+ : sum)
for (size_t i = 0; i < numel; ++i) {
sum += static_cast<float>(value[i] - value[i]);
}
if (std::isnan(sum) || std::isinf(sum)) {
PrintNanInf(value, numel, print_num, op_type, var_name);
}
}
#endif
template <>
......
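
The bfloat16 CheckNanInf specialization above uses the same trick as the float16 one: for any finite v, v - v == 0, while NaN - NaN and Inf - Inf are NaN, so a single reduction over value[i] - value[i] flags any bad element. A plain-float sketch of the idea (the helper name is mine):

#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Returns true if any element is NaN or +/-Inf: finite values contribute 0
// to the sum, NaN/Inf elements turn the whole sum into NaN.
bool HasNanInf(const std::vector<float>& values) {
  float sum = 0.0f;
  for (float v : values) sum += v - v;
  return std::isnan(sum) || std::isinf(sum);
}

int main() {
  std::vector<float> ok = {1.f, -2.f, 3.5f};
  std::vector<float> bad = {1.f, std::numeric_limits<float>::infinity()};
  std::printf("%d %d\n", static_cast<int>(HasNanInf(ok)),
              static_cast<int>(HasNanInf(bad)));  // prints "0 1"
}
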
......@@ -23,6 +23,7 @@ template <typename T>
static ::DLDataType GetDLDataTypeCode() {
::DLDataType dtype;
if (std::is_same<T, platform::float16>::value ||
std::is_same<T, platform::bfloat16>::value ||
std::is_floating_point<T>::value) {
dtype.code = kDLFloat;
} else if (std::is_unsigned<T>::value) {
......
......@@ -21,10 +21,46 @@ limitations under the License. */
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(int64_t device_id) {
#ifdef PADDLE_WITH_CUDA
static int64_t num_cuda_devices = -1;
static std::once_flag num_devices_init_flag;
static std::deque<std::once_flag> cuda_device_flags;
static std::vector<std::shared_ptr<Generator>> default_cuda_generators;
std::call_once(num_devices_init_flag, []() {
num_cuda_devices = paddle::platform::GetCUDADeviceCount();
cuda_device_flags.resize(num_cuda_devices);
default_cuda_generators.resize(num_cuda_devices);
});
if (device_id < 0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"cuda device id shoule be greater than 0"));
}
std::call_once(cuda_device_flags[device_id], [device_id]() {
default_cuda_generators[device_id] =
std::make_shared<Generator>(GetRandomSeed(), device_id);
VLOG(4) << "initial seed: "
<< default_cuda_generators[device_id]->GetCurrentSeed();
});
return default_cuda_generators[device_id];
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"getDefaultCUDAGenerator only support in CUDA place"));
#endif
}
const std::shared_ptr<Generator>& DefaultCPUGenerator() {
static auto default_cpu_generator =
std::make_shared<Generator>(GetRandomSeed());
......@@ -103,6 +139,7 @@ uint64_t Generator::Seed() {
void Generator::SetCurrentSeed(uint64_t seed) {
std::lock_guard<std::mutex> lock(this->mu_);
this->state_.current_seed = seed;
this->state_.thread_offset = 0;
std::seed_seq seq({seed});
this->engine_->seed(seq);
}
......@@ -123,6 +160,22 @@ uint64_t Generator::Random64() {
return (*engine)();
}
std::pair<uint64_t, uint64_t> Generator::IncrementOffset(
uint64_t increament_offset) {
uint64_t cur_offset = this->state_.thread_offset;
#ifdef PADDLE_WITH_CUDA
std::lock_guard<std::mutex> lock(this->mu_);
this->state_.thread_offset += increament_offset;
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Increment Offset only support in CUDA place"));
#endif
return std::make_pair(static_cast<int>(this->state_.current_seed),
cur_offset);
}
void Generator::SetIsInitPy(bool is_init_py) {
this->is_init_py_ = is_init_py;
VLOG(4) << "SetIsInitPy:" << this->is_init_py_;
......
......@@ -38,6 +38,7 @@ static uint64_t GetRandomSeed() {
struct GeneratorState {
int64_t device = -1;
uint64_t current_seed = 34342423252;
uint64_t thread_offset = 0;
std::mt19937_64 cpu_engine;
};
......@@ -49,6 +50,7 @@ struct Generator {
this->state_.cpu_engine = *engine;
this->state_.device = -1;
this->state_.current_seed = seed;
this->state_.thread_offset = 0;
this->engine_ = engine;
VLOG(4) << "initial seed: " << this->state_.current_seed
<< ", cpu engine: " << &this->state_.cpu_engine;
......@@ -59,11 +61,25 @@ struct Generator {
this->state_.cpu_engine = *engine;
this->state_.device = -1;
this->state_.current_seed = seed;
this->state_.thread_offset = 0;
this->engine_ = engine;
VLOG(4) << "initial seed: " << this->state_.current_seed
<< ", cpu engine: " << &this->state_.cpu_engine;
this->is_init_py_ = true; // TODO(zhiqiu): remove it in future
}
Generator(uint64_t seed, uint64_t device_id) {
std::seed_seq seq({seed});
auto engine = std::make_shared<std::mt19937_64>(seq);
this->state_.cpu_engine = *engine;
this->state_.device = device_id;
this->state_.current_seed = seed;
this->state_.thread_offset = 0;
this->engine_ = engine;
VLOG(4) << "initial seed: " << this->state_.current_seed
<< ", cpu engine: " << &this->state_.cpu_engine;
this->is_init_py_ = false; // TODO(zhiqiu): remove it in future
}
Generator(const Generator& other) = delete;
// get random state
......@@ -83,8 +99,11 @@ struct Generator {
uint64_t Random64();
std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t increament_offset);
void SetIsInitPy(bool);
bool GetIsInitPy() const;
uint64_t get_device_id() { return this->state_.device; }
private:
GeneratorState state_;
......@@ -105,5 +124,8 @@ std::shared_ptr<std::mt19937_64> OpDefaultCPUEngine();
std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);
const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(
int64_t device_id = -1);
} // namespace framework
} // namespace paddle
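
The new thread_offset field and IncrementOffset method are host-side bookkeeping for counter-based RNGs: each kernel launch reserves a slice of the random stream and receives (seed, previous offset), so repeated launches under one seed never reuse Philox counters (see the dropout and gaussian_random kernels later in this commit). A stripped-down sketch of that contract, with illustrative names rather than Paddle's classes:

#include <cstdint>
#include <cstdio>
#include <mutex>
#include <utility>

// Hypothetical per-device generator state mirroring GeneratorState above:
// a fixed seed plus a monotonically increasing offset shared by consumers.
class OffsetGenerator {
 public:
  explicit OffsetGenerator(uint64_t seed) : seed_(seed), offset_(0) {}

  // Reserve `increment` counter steps and return (seed, offset before bump),
  // matching the contract of Generator::IncrementOffset in the diff.
  std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t increment) {
    std::lock_guard<std::mutex> lock(mu_);
    uint64_t cur = offset_;
    offset_ += increment;
    return {seed_, cur};
  }

 private:
  std::mutex mu_;
  uint64_t seed_;
  uint64_t offset_;
};

int main() {
  OffsetGenerator gen(34342423252ULL);
  auto first = gen.IncrementOffset(1);   // {seed, 0} for the first launch
  auto second = gen.IncrementOffset(1);  // {seed, 1} for the next launch
  std::printf("%llu %llu\n",
              static_cast<unsigned long long>(first.second),
              static_cast<unsigned long long>(second.second));  // "0 1"
}
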
......@@ -615,6 +615,16 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
multihead_pattern);
// If weights or biases in qkv's fc are shared by multiple multihead_matmul
// patterns, we do not support this kind of fusion, so this pass will not
// take effect.
bool is_fc_params_shared =
mul0_w->outputs.size() > 1 || mul1_w->outputs.size() > 1 ||
mul2_w->outputs.size() > 1 || eltadd0_b->outputs.size() > 1 ||
eltadd1_b->outputs.size() > 1 || eltadd2_b->outputs.size() > 1;
if (is_fc_params_shared) {
return;
}
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
reshape2_0, reshape2_qkv_out, scale, scale_out);
......
......@@ -90,32 +90,6 @@ void MemoryOptimizePass::CollectLifeCycle(
}
}
// TODO(Superjomn) Make this a general help method.
int DataTypeToSpace(framework::proto::VarType_Type type) {
switch (type) {
case framework::proto::VarType_Type_BOOL:
return sizeof(bool);
case framework::proto::VarType_Type_FP32:
return sizeof(float);
case framework::proto::VarType_Type_INT32:
return sizeof(int32_t);
case framework::proto::VarType_Type_INT64:
return sizeof(int64_t);
case framework::proto::VarType_Type_INT16:
return sizeof(int16_t);
case framework::proto::VarType_Type_FP16:
return sizeof(int16_t);
case framework::proto::VarType_Type_FP64:
return sizeof(double);
case framework::proto::VarType_Type_UINT8:
return sizeof(unsigned char);
case framework::proto::VarType_Type_INT8:
return sizeof(int8_t);
default:
PADDLE_THROW("Unknown data type");
}
}
void MemoryOptimizePass::CollectVarMemorySize(
space_table_t* space_table) const {
const int fake_batch_size = 1;
......@@ -163,7 +137,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
int size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int>());
(*space_table)[node->Var()->Name()] =
size * DataTypeToSpace(node->Var()->GetDataType());
size * paddle::framework::SizeOfType(node->Var()->GetDataType());
}
}
}
......
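
With the local DataTypeToSpace switch removed, the per-variable memory estimate is simply element count times the byte width framework::SizeOfType reports for the dtype. A self-contained sketch of that estimate; the byte-width table below is illustrative and merely stands in for SizeOfType:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <unordered_map>
#include <vector>

// Illustrative byte widths keyed by a proto dtype id; Paddle's
// framework::SizeOfType plays this role in the pass.
std::size_t SizeOfTypeStub(int proto_type) {
  static const std::unordered_map<int, std::size_t> widths = {
      {0 /*BOOL*/, 1}, {2 /*INT32*/, 4}, {3 /*INT64*/, 8},
      {4 /*FP16*/, 2}, {5 /*FP32*/, 4},  {6 /*FP64*/, 8}};
  return widths.at(proto_type);
}

std::size_t EstimateBytes(const std::vector<int64_t>& shape, int proto_type) {
  int64_t numel = std::accumulate(shape.begin(), shape.end(),
                                  static_cast<int64_t>(1),
                                  std::multiplies<int64_t>());
  return static_cast<std::size_t>(numel) * SizeOfTypeStub(proto_type);
}

int main() {
  // 1x3x224x224 FP32 tensor -> 150528 elements * 4 bytes = 602112 bytes.
  std::printf("%zu\n", EstimateBytes({1, 3, 224, 224}, 5 /*FP32*/));
}
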
......@@ -73,7 +73,7 @@ class PD_INFER_DECL Tensor {
class PD_INFER_DECL Predictor {
public:
Predictor() = default;
Predictor() = delete;
~Predictor() {}
// Use for clone
explicit Predictor(std::unique_ptr<paddle::PaddlePredictor>&& pred)
......
......@@ -14,15 +14,16 @@
#include <gtest/gtest.h>
#include "paddle/fluid/inference/lite/engine.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/operators/lite/ut_helper.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/lite/engine.h"
#include "paddle/fluid/operators/lite/ut_helper.h"
namespace paddle {
namespace inference {
namespace lite {
......
......@@ -125,7 +125,7 @@ endfunction()
if(NOT APPLE AND WITH_MKLML)
# RNN1
set(RNN1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn1")
download_model_and_data(${RNN1_INSTALL_DIR} "rnn1%2Fmodel.tar.gz" "rnn1%2Fdata.txt.tar.gz")
download_model_and_data(${RNN1_INSTALL_DIR} "rnn1/model.tar.gz" "rnn1/data.txt.tar.gz")
inference_analysis_api_test(test_analyzer_rnn1 ${RNN1_INSTALL_DIR} analyzer_rnn1_tester.cc)
# seq_pool1
......@@ -210,7 +210,7 @@ inference_analysis_api_test(test_analyzer_seq_conv1 ${SEQ_CONV1_INSTALL_DIR} ana
# transformer, the dataset only works on batch_size=8 now
set(TRANSFORMER_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/transformer")
download_model_and_data(${TRANSFORMER_INSTALL_DIR} "temp%2Ftransformer_model.tar.gz" "temp%2Ftransformer_data.txt.tar.gz")
download_model_and_data(${TRANSFORMER_INSTALL_DIR} "temp/transformer_model.tar.gz" "temp/transformer_data.txt.tar.gz")
inference_analysis_test(test_analyzer_transformer SRCS analyzer_transformer_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRANSFORMER_INSTALL_DIR}/model --infer_data=${TRANSFORMER_INSTALL_DIR}/data.txt --batch_size=8
......@@ -219,7 +219,7 @@ inference_analysis_test(test_analyzer_transformer SRCS analyzer_transformer_test
# ocr
set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr")
if (NOT EXISTS ${OCR_INSTALL_DIR}/ocr.tar.gz)
inference_download_and_uncompress(${OCR_INSTALL_DIR} "http://paddlemodels.bj.bcebos.com/" "inference-vis-demos%2Focr.tar.gz")
inference_download_and_uncompress(${OCR_INSTALL_DIR} "http://paddlemodels.bj.bcebos.com/" "inference-vis-demos/ocr.tar.gz")
endif()
inference_analysis_api_test(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_tester.cc)
......@@ -235,7 +235,7 @@ set_property(TEST test_analyzer_detect PROPERTY ENVIRONMENT GLOG_vmodule=analysi
# mobilenet with transpose op
set(MOBILENET_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet")
if (NOT EXISTS ${MOBILENET_INSTALL_DIR}/mobilenet.tar.gz)
inference_download_and_uncompress(${MOBILENET_INSTALL_DIR} "http://paddlemodels.bj.bcebos.com/" "inference-vis-demos%2Fmobilenet.tar.gz")
inference_download_and_uncompress(${MOBILENET_INSTALL_DIR} "http://paddlemodels.bj.bcebos.com/" "inference-vis-demos/mobilenet.tar.gz")
endif()
inference_analysis_api_test(test_analyzer_mobilenet_transpose ${MOBILENET_INSTALL_DIR} analyzer_vis_tester.cc)
......@@ -363,9 +363,9 @@ if(WITH_MKLDNN)
inference_analysis_api_test_build(${QUANT_IMG_CLASS_TEST_APP} ${QUANT_IMG_CLASS_TEST_APP_SRC})
# MobileNetV1 FP32 vs. Quant INT8
# The FP32 model should already be downloaded for slim Quant unit tests
set(QUANT2_MobileNetV1_MODEL_DIR "${QUANT_DATA_DIR}/MobileNetV1_quant2")
set(QUANT2_INT8_MobileNetV1_MODEL_DIR "${QUANT_DATA_DIR}/MobileNetV1_quant2_int8")
download_quant_data(${QUANT2_MobileNetV1_MODEL_DIR} "MobileNet_qat_perf.tar.gz")
download_quant_data(${QUANT2_INT8_MobileNetV1_MODEL_DIR} "MobileNet_qat_perf_int8.tar.gz")
inference_analysis_api_quant_test_run(test_analyzer_quant_performance_benchmark ${QUANT_IMG_CLASS_TEST_APP} ${QUANT2_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf/float ${QUANT2_INT8_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf_int8 ${IMAGENET_DATA_PATH})
......@@ -477,9 +477,10 @@ if(WITH_GPU AND TENSORRT_FOUND)
inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_unserialized.tgz")
endif()
inference_analysis_test(test_trt_dynamic_shape_ernie_ser_deser SRCS trt_dynamic_shape_ernie_deserialize_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_unserialized)
# disable test_trt_dynamic_shape_ernie_ser_deser temporary
#inference_analysis_test(test_trt_dynamic_shape_ernie_ser_deser SRCS trt_dynamic_shape_ernie_deserialize_test.cc
# EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
# ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_unserialized)
endif()
......
......@@ -44,7 +44,7 @@ void zero_copy_run() {
const int channels = 3;
const int height = 318;
const int width = 318;
float input[batch_size * channels * height * width] = {0};
float *input = new float[batch_size * channels * height * width]();
int shape[4] = {batch_size, channels, height, width};
int shape_size = 4;
......@@ -65,6 +65,7 @@ void zero_copy_run() {
PD_PredictorZeroCopyRun(config, inputs, in_size, &outputs, &out_size);
delete[] input;
delete[] inputs;
delete[] outputs;
}
......
......@@ -112,7 +112,11 @@ TEST(Analyzer_resnet50, compare_determine) {
TEST(Analyzer_resnet50, save_optim_model) {
AnalysisConfig cfg;
std::string optimModelPath = FLAGS_infer_model + "/saved_optim_model";
#ifdef _WIN32
_mkdir(optimModelPath.c_str());
#else
mkdir(optimModelPath.c_str(), 0777);
#endif
SetConfig(&cfg);
SaveOptimModel(&cfg, optimModelPath);
}
......
......@@ -123,7 +123,7 @@ void profile(bool memory_load = false) {
size_t size = GetSize(output[0]);
PADDLE_ENFORCE_GT(size, 0);
int64_t *result = static_cast<int64_t *>(output[0].data.data());
for (size_t i = 0; i < std::min(11UL, size); i++) {
for (size_t i = 0; i < std::min<size_t>(11, size); i++) {
EXPECT_EQ(result[i], chinese_ner_result_data[i]);
}
}
......
......@@ -23,7 +23,7 @@ from PIL import Image
import math
from paddle.dataset.common import download
import tarfile
import StringIO
from six.moves import StringIO
import argparse
random.seed(0)
......@@ -152,7 +152,7 @@ def convert_Imagenet_tar2bin(tar_file, output_file):
idx = 0
for imagedata in dataset.values():
img = Image.open(StringIO.StringIO(imagedata))
img = Image.open(StringIO(imagedata))
img = process_image(img)
np_img = np.array(img)
ofs.write(np_img.astype('float32').tobytes())
......
......@@ -19,7 +19,7 @@ import os
import sys
from paddle.dataset.common import download
import tarfile
import StringIO
from six.moves import StringIO
import hashlib
import tarfile
import argparse
......@@ -191,7 +191,7 @@ def convert_pascalvoc_tar2bin(tar_path, data_out_path):
gt_labels[name_prefix] = tar.extractfile(tarInfo).read()
for line_idx, name_prefix in enumerate(lines):
im = Image.open(StringIO.StringIO(images[name_prefix]))
im = Image.open(StringIO(images[name_prefix]))
if im.mode == 'L':
im = im.convert('RGB')
im_width, im_height = im.size
......
......@@ -25,7 +25,8 @@ endfunction()
function(inference_download_and_uncompress INSTALL_DIR URL FILENAME)
message(STATUS "Download inference test stuff from ${URL}/${FILENAME}")
string(REGEX REPLACE "[-%.]" "_" FILENAME_EX ${FILENAME})
string(REGEX REPLACE "[-%./\\]" "_" FILENAME_EX ${FILENAME})
string(REGEX MATCH "[^/\\]+$" DOWNLOAD_NAME ${FILENAME})
set(EXTERNAL_PROJECT_NAME "extern_inference_download_${FILENAME_EX}")
set(UNPACK_DIR "${INSTALL_DIR}/src/${EXTERNAL_PROJECT_NAME}")
ExternalProject_Add(
......@@ -38,7 +39,7 @@ function(inference_download_and_uncompress INSTALL_DIR URL FILENAME)
DOWNLOAD_NO_PROGRESS 1
CONFIGURE_COMMAND ""
BUILD_COMMAND ${CMAKE_COMMAND} -E chdir ${INSTALL_DIR}
${CMAKE_COMMAND} -E tar xzf ${FILENAME}
${CMAKE_COMMAND} -E tar xzf ${DOWNLOAD_NAME}
UPDATE_COMMAND ""
INSTALL_COMMAND ""
)
......
......@@ -62,11 +62,11 @@ __global__ void affine_grid_kernel(const int count, int n, int out_h, int out_w,
int theta_offset = n * 6; // 2 * 3;
// affine from (h_coor, w_coor) to (x, y)
output[index * 2] = theta[theta_offset] * h_coor +
theta[theta_offset + 1] * w_coor +
output[index * 2] = theta[theta_offset] * w_coor +
theta[theta_offset + 1] * h_coor +
theta[theta_offset + 2];
output[index * 2 + 1] = theta[theta_offset + 3] * h_coor +
theta[theta_offset + 4] * w_coor +
output[index * 2 + 1] = theta[theta_offset + 3] * w_coor +
theta[theta_offset + 4] * h_coor +
theta[theta_offset + 5];
}
}
......@@ -86,13 +86,13 @@ __global__ void affine_grid_grad_kernel(const int count, int n, int out_h,
int theta_offset = n * 6; // 2 * 3;
T out_grad_x = out_grad[index * 2];
platform::CudaAtomicAdd(theta_grad + theta_offset, out_grad_x * h_coor);
platform::CudaAtomicAdd(theta_grad + theta_offset + 1, out_grad_x * w_coor);
platform::CudaAtomicAdd(theta_grad + theta_offset, out_grad_x * w_coor);
platform::CudaAtomicAdd(theta_grad + theta_offset + 1, out_grad_x * h_coor);
platform::CudaAtomicAdd(theta_grad + theta_offset + 2, out_grad_x);
T out_grad_y = out_grad[index * 2 + 1];
platform::CudaAtomicAdd(theta_grad + theta_offset + 3, out_grad_y * h_coor);
platform::CudaAtomicAdd(theta_grad + theta_offset + 4, out_grad_y * w_coor);
platform::CudaAtomicAdd(theta_grad + theta_offset + 3, out_grad_y * w_coor);
platform::CudaAtomicAdd(theta_grad + theta_offset + 4, out_grad_y * h_coor);
platform::CudaAtomicAdd(theta_grad + theta_offset + 5, out_grad_y);
}
}
......
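
For reference, my reading of the corrected affine_grid kernels: theta acts on the coordinate ordered as (x, y) = (w, h), i.e.

x = \theta_0\, w + \theta_1\, h + \theta_2, \qquad y = \theta_3\, w + \theta_4\, h + \theta_5,

so the backward pass pairs \theta_0 and \theta_3 with w_coor and \theta_1 and \theta_4 with h_coor, which is exactly the ordering of the fixed CudaAtomicAdd calls above.
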
......@@ -166,10 +166,22 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument(
"'axis'(%d) must be less than Rank(X)(%d).", axis, x_dims.size()));
const int& dtype = ctx->Attrs().Get<int>("dtype");
PADDLE_ENFORCE_EQ(
(dtype < 0 || dtype == 2 || dtype == 3), true,
platform::errors::InvalidArgument(
"The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
"received [%s]",
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64),
paddle::framework::DataTypeToString(
static_cast<framework::proto::VarType::Type>(dtype))));
auto x_rank = x_dims.size();
if (axis < 0) axis += x_rank;
if (ctx->IsRuntime()) {
const int& dtype = ctx->Attrs().Get<int>("dtype");
if (dtype == framework::proto::VarType::INT32) {
int64_t all_element_num = 0;
if (flatten) {
......
......@@ -16,7 +16,6 @@ limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/bernoulli_op.h"
......
......@@ -37,41 +37,42 @@ class CudnnLSTMOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasOutput("LastC"), "Output", "LastC", "CudnnLSTM");
auto in_dims = ctx->GetInputDim("Input");
auto init_dims = ctx->GetInputDim("InitH");
auto init_h_dims = ctx->GetInputDim("InitH");
auto init_c_dims = ctx->GetInputDim("InitC");
PADDLE_ENFORCE_EQ(in_dims.size(), 3,
platform::errors::InvalidArgument(
"The rank of Input in CudnnLSTM must be 3. But "
"received Input's rank is %d.",
in_dims.size()));
PADDLE_ENFORCE_EQ(init_dims.size(), 3,
PADDLE_ENFORCE_EQ(init_h_dims.size(), 3,
platform::errors::InvalidArgument(
"The rank of InitH in CudnnLSTM must be 3. But "
"received InitH's rank is %d.",
init_dims.size()));
init_h_dims.size()));
PADDLE_ENFORCE_EQ(in_dims[1], init_dims[1],
platform::errors::InvalidArgument(
"The in_dims[1] (Input dims) and init_dims[1] (InitH "
"dims) should be equal. But "
"received in_dims[1] is %d and init_dims[1] is %d.",
in_dims[1], init_dims[1]));
PADDLE_ENFORCE_EQ(in_dims[2], init_dims[2],
PADDLE_ENFORCE_EQ(
in_dims[1], init_h_dims[1],
platform::errors::InvalidArgument(
"The in_dims[1] (Input dims) and init_h_dims[1] (InitH "
"dims) should be equal. But "
"received in_dims[1] is %d and init_h_dims[1] is %d.",
in_dims[1], init_h_dims[1]));
PADDLE_ENFORCE_EQ(init_c_dims, init_h_dims,
platform::errors::InvalidArgument(
"The in_dims[2] (Input dims) and init_dims[2] (InitH "
"dims) should be equal. But "
"received in_dims[2] is %d and init_dims[2] is %d.",
in_dims[2], init_dims[2]));
"The InitC dims and InitH "
"dims should be equal. But "
"received init_c_dims is %d and init_h_dims is %d.",
init_c_dims, init_h_dims));
auto out_dims = in_dims;
auto hidden_size = ctx->Attrs().Get<int>("hidden_size");
bool is_bidirec = ctx->Attrs().Get<bool>("is_bidirec");
out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size;
auto last_dims = init_dims;
last_dims[0] = is_bidirec ? last_dims[0] * 2 : last_dims[0];
ctx->SetOutputDim("Out", out_dims);
ctx->SetOutputDim("LastH", last_dims);
ctx->SetOutputDim("LastC", last_dims);
ctx->SetOutputDim("LastH", init_c_dims);
ctx->SetOutputDim("LastC", init_h_dims);
}
protected:
......@@ -95,7 +96,7 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"different batch)"
"batch_size is the instance number of this batch"
"input_size is the hidden size of the input."
"input_hidden_size and the hidden_size in the next may not be same");
"input_size and the hidden_size in the next may not be same");
AddInput("InitH",
"(Tensor) the initial hidden state of the LSTM"
"input. This is a tensor with shape (num_layers x batch_size x "
......@@ -154,6 +155,13 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(1);
AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
AddAttr<int>("seed", "seed to used if fix_seed is True").SetDefault(0);
AddAttr<std::vector<int>>("sequence_length",
"(vector<int>) When the input data is padding, "
"set this parameter. This parameter represents "
"the variable sequence"
"lengths in a batch. The size of the vector has "
"to equal the batch_size.")
.SetDefault({});
AddComment(R"DOC(
CUDNN LSTM implementation
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/operators/cudnn_rnn_cache.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_desc.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
......@@ -55,50 +56,96 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
int num_layers = ctx.Attr<int>("num_layers");
bool is_test = ctx.Attr<bool>("is_test");
int seed = ctx.Attr<int>("seed");
auto sequence_length = ctx.Attr<std::vector<int>>("sequence_length");
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
CudnnRNNCache *cudnn_rnn_cache = new CudnnRNNCache();
int seq_length = x->dims()[0];
int batch_size = x->dims()[1];
int input_size = x->dims()[2];
int weight_numel = w->numel();
bool state_initialized = state_out->IsInitialized() ? true : false;
auto input_w_numel = w->numel();
auto seq_len = x->dims()[0];
auto batch_size = x->dims()[1];
auto input_dim = x->dims()[2];
size_t workspace_size;
size_t reserve_size;
bool state_initialized = state_out->IsInitialized() ? true : false;
cudnnDataType_t cudnn_type = platform::ToCudnnDataType(
framework::ToDataType(std::type_index(typeid(T))));
cudnn_rnn_cache->init(handle, ctx.GetPlace(), seq_len, batch_size,
input_dim, hidden_size, num_layers, dropout_prob,
is_bidirec, seed, input_w_numel, &reserve_size,
state_out, state_initialized, cudnn_type);
platform::ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
num_layers, dropout_prob, seed, weight_numel,
state_initialized, is_bidirec);
rnn.Create<T>(handle, ctx.GetPlace(), sequence_length, &workspace_size,
&reserve_size, state_out);
framework::Tensor workspace_data_;
workspace_data_.Resize({static_cast<int64_t>(workspace_size)});
workspace_data_.mutable_data<uint8_t>(ctx.GetPlace());
auto *reserve_data = reserve->mutable_data<uint8_t>(
{static_cast<int64_t>(reserve_size)}, ctx.GetPlace());
if (is_test) {
// for inference
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference(
handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->x_desc_,
x_data, cudnn_rnn_cache->hx_desc_, init_h_data,
cudnn_rnn_cache->cx_desc_, init_c_data, cudnn_rnn_cache->w_desc_,
w_data, cudnn_rnn_cache->y_desc_, out_data, cudnn_rnn_cache->hy_desc_,
last_h_data, cudnn_rnn_cache->cy_desc_, last_c_data,
cudnn_rnn_cache->workspace_data_.data<uint8_t>(),
cudnn_rnn_cache->workspace_size_));
if (sequence_length.empty()) {
// for inference
// This interface is used when the input/output is unpadded.
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference(
handle, rnn.rnn_desc(), seq_length, rnn.x_desc(), x_data,
rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data,
rnn.w_desc(), w_data, rnn.y_desc(), out_data, rnn.hy_desc(),
last_h_data, rnn.cy_desc(), last_c_data,
workspace_data_.data<uint8_t>(), workspace_size));
} else {
#if CUDNN_VERSION >= 7201
// for inference
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnRNNForwardInferenceEx(
handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data, rnn.hx_desc(),
init_h_data, rnn.cx_desc(), init_c_data, rnn.w_desc(), w_data,
rnn.y_seq_desc(), out_data, rnn.hy_desc(), last_h_data,
rnn.cy_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr,
workspace_data_.data<uint8_t>(), workspace_size));
#else
PADDLE_ENFORCE_NOT_NULL(
nullptr, platform::errors::Unavailable(
"The padded input is supported by "
"cudnnRNNForwardInferenceEx, but it only works when "
"the version of cudnn is larger than 7.2.1"));
#endif
}
} else {
// for train
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardTraining(
handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->x_desc_,
x_data, cudnn_rnn_cache->hx_desc_, init_h_data,
cudnn_rnn_cache->cx_desc_, init_c_data, cudnn_rnn_cache->w_desc_,
w_data, cudnn_rnn_cache->y_desc_, out_data, cudnn_rnn_cache->hy_desc_,
last_h_data, cudnn_rnn_cache->cy_desc_, last_c_data,
cudnn_rnn_cache->workspace_data_.data<uint8_t>(),
cudnn_rnn_cache->workspace_size_, reserve_data, reserve_size));
if (sequence_length.empty()) {
// for train
// This interface is used when the input/output is unpadded.
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardTraining(
handle, rnn.rnn_desc(), seq_length, rnn.x_desc(), x_data,
rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data,
rnn.w_desc(), w_data, rnn.y_desc(), out_data, rnn.hy_desc(),
last_h_data, rnn.cy_desc(), last_c_data,
workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
reserve_size));
} else {
#if CUDNN_VERSION >= 7201
// for train
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnRNNForwardTrainingEx(
handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data, rnn.hx_desc(),
init_h_data, rnn.cx_desc(), init_c_data, rnn.w_desc(), w_data,
rnn.y_seq_desc(), out_data, rnn.hy_desc(), last_h_data,
rnn.cy_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr,
workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
reserve_size));
#else
PADDLE_ENFORCE_NOT_NULL(
nullptr, platform::errors::Unavailable(
"The padded input is supported by "
"cudnnRNNForwardTrainingEx, but it only works when "
"the version of cudnn is larger than 7.2.1"));
#endif
}
}
delete cudnn_rnn_cache;
}
};
......@@ -156,44 +203,74 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
int hidden_size = ctx.Attr<int>("hidden_size");
int num_layers = ctx.Attr<int>("num_layers");
int seed = ctx.Attr<int>("seed");
auto sequence_length = ctx.Attr<std::vector<int>>("sequence_length");
CudnnRNNCache *cudnn_rnn_cache = new CudnnRNNCache();
int seq_length = input_dims[0];
int batch_size = input->dims()[1];
int input_size = input->dims()[2];
int weight_numel = weight->numel();
auto input_w_numel = weight->numel();
auto seq_len = input_dims[0];
auto batch_size = input->dims()[1];
auto input_dim = input->dims()[2];
size_t workspace_size;
size_t reserve_size;
cudnnDataType_t cudnn_type = platform::ToCudnnDataType(
framework::ToDataType(std::type_index(typeid(T))));
cudnn_rnn_cache->init(handle, ctx.GetPlace(), seq_len, batch_size,
input_dim, hidden_size, num_layers, dropout_prob,
is_bidirec, seed, input_w_numel, &reserve_size,
const_cast<Tensor *>(state_out), true, cudnn_type);
auto work_data = cudnn_rnn_cache->workspace_data_.data<uint8_t>();
platform::ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
num_layers, dropout_prob, seed, weight_numel,
true, is_bidirec);
rnn.Create<T>(handle, ctx.GetPlace(), sequence_length, &workspace_size,
&reserve_size, const_cast<Tensor *>(state_out));
framework::Tensor workspace_data_;
workspace_data_.Resize({static_cast<int64_t>(workspace_size)});
workspace_data_.mutable_data<uint8_t>(ctx.GetPlace());
const uint8_t *reserve_data = reserve->data<uint8_t>();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData(
handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->y_desc_,
out_data, cudnn_rnn_cache->y_desc_, out_grad_data,
cudnn_rnn_cache->hy_desc_, last_h_grad_data, cudnn_rnn_cache->cy_desc_,
last_c_grad_data, cudnn_rnn_cache->w_desc_, weight_data,
cudnn_rnn_cache->hx_desc_, init_h_data, cudnn_rnn_cache->cx_desc_,
init_c_data, cudnn_rnn_cache->x_desc_, in_grad_data,
cudnn_rnn_cache->hx_desc_, init_h_grad_data, cudnn_rnn_cache->cx_desc_,
init_c_grad_data, work_data, cudnn_rnn_cache->workspace_size_,
const_cast<uint8_t *>(reserve_data), reserve_size));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->x_desc_,
input->data<T>(), cudnn_rnn_cache->hx_desc_, init_h->data<T>(),
cudnn_rnn_cache->y_desc_, out->data<T>(),
cudnn_rnn_cache->workspace_data_.data<uint8_t>(),
cudnn_rnn_cache->workspace_size_, cudnn_rnn_cache->w_desc_,
weight_grad->data<T>(), const_cast<uint8_t *>(reserve_data),
reserve_size));
delete cudnn_rnn_cache;
if (sequence_length.empty()) {
// This interface is used when the input/output is unpadded.
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData(
handle, rnn.rnn_desc(), seq_length, rnn.y_desc(), out_data,
rnn.y_desc(), out_grad_data, rnn.hy_desc(), last_h_grad_data,
rnn.cy_desc(), last_c_grad_data, rnn.w_desc(), weight_data,
rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data, rnn.x_desc(),
in_grad_data, rnn.hx_desc(), init_h_grad_data, rnn.cx_desc(),
init_c_grad_data, workspace_data_.data<uint8_t>(), workspace_size,
const_cast<uint8_t *>(reserve_data), reserve_size));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
handle, rnn.rnn_desc(), seq_length, rnn.x_desc(), input->data<T>(),
rnn.hx_desc(), init_h->data<T>(), rnn.y_desc(), out->data<T>(),
workspace_data_.data<uint8_t>(), workspace_size, rnn.w_desc(),
weight_grad->data<T>(), const_cast<uint8_t *>(reserve_data),
reserve_size));
} else {
#if CUDNN_VERSION >= 7201
// for train
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx(
handle, rnn.rnn_desc(), rnn.y_seq_desc(), out_data, rnn.y_seq_desc(),
out_grad_data, nullptr, nullptr, rnn.hy_desc(), last_h_grad_data,
rnn.cy_desc(), last_c_grad_data, rnn.w_desc(), weight_data,
rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data,
rnn.x_seq_desc(), in_grad_data, rnn.hx_desc(), init_h_grad_data,
rnn.cx_desc(), init_c_grad_data, nullptr, nullptr,
workspace_data_.data<uint8_t>(), workspace_size,
const_cast<uint8_t *>(reserve_data), reserve_size));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeightsEx(
handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data<T>(),
rnn.hx_desc(), init_h->data<T>(), rnn.y_seq_desc(), out->data<T>(),
workspace_data_.data<uint8_t>(), workspace_size, rnn.w_desc(),
weight_grad->data<T>(), const_cast<uint8_t *>(reserve_data),
reserve_size));
#else
PADDLE_ENFORCE_NOT_NULL(
nullptr,
platform::errors::Unavailable(
"The padded input of rnn is supported by cudnnRNNBackwardDataEx, "
"cudnnRNNBackwardWeightsEx, but it only works when the version "
"of cudnn is larger than 7.2.1"));
#endif
}
}
};
......
......@@ -62,6 +62,34 @@ bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input,
gpu_dev_ctx.Wait();
#else
PADDLE_THROW("Unexpected branch");
#endif
return true;
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
auto& xpu_dev_ctx = static_cast<const platform::XPUDeviceContext&>(dev_ctx);
platform::CPUPlace cpu;
char* p = reinterpret_cast<char*>(dest);
while (total_written < length) {
if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
return false;
}
if (total_written + size_to_write > length) {
size_to_write = length - total_written;
}
memory::Copy(BOOST_GET_CONST(platform::XPUPlace, place),
reinterpret_cast<void*>(p), cpu, data, size_to_write);
p += size_to_write;
total_written += size_to_write;
input->Skip(size_to_write);
}
xpu_dev_ctx.Wait();
#else
PADDLE_ENFORCE_NOT_NULL(
nullptr,
platform::errors::Unimplemented(
"Not supported XPU, please compile with option WITH_XPU=ON."));
#endif
return true;
}
......
......@@ -96,6 +96,42 @@ __global__ void RandomGeneratorWithSeed(const size_t n, const int* seed,
}
}
template <typename T, typename MaskType>
__global__ void RandomGeneratorWithGenerator(const size_t n, uint64_t seed,
const float dropout_prob,
const T* src, MaskType* mask_data,
T* dst, bool is_upscale_in_train,
uint64_t increment) {
curandStatePhilox4_32_10_t state;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int step_size = 0;
MaskType mask;
T dest;
for (; idx < n; idx += blockDim.x * gridDim.x) {
T s = src[idx];
if (step_size == 0) {
curand_init(seed, idx, increment, &state);
step_size = blockDim.x * gridDim.x;
} else {
curand_init(seed, idx, increment, &state);
}
if (curand_uniform(&state) < dropout_prob) {
mask = 0;
dest = 0;
} else {
mask = 1;
if (is_upscale_in_train) {
dest = s / static_cast<T>(1.0f - dropout_prob);
} else {
dest = s;
}
}
mask_data[idx] = mask;
dst[idx] = dest;
}
}
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
......@@ -150,6 +186,17 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
}
int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace())
.GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && (!context.Attr<bool>("fix_seed"))) {
auto seed_offset = gen_cuda->IncrementOffset(1);
RandomGeneratorWithGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
size, seed_offset.first, dropout_prob, x_data, mask_data, y_data,
upscale_in_train, seed_offset.second);
return;
}
RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
size, seed_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train);
......
......@@ -31,6 +31,15 @@ struct ModFunctor {
}
};
template <typename T>
struct InverseModFunctor {
inline HOSTDEVICE T operator()(T a, T b) const {
T res = b % a;
if ((res != 0) && ((res < 0) != (a < 0))) res += a;
return res;
}
};
template <typename T>
struct ModFunctorFP {
inline HOSTDEVICE T operator()(T a, T b) const {
......@@ -40,13 +49,29 @@ struct ModFunctorFP {
}
};
template <typename T>
struct InverseModFunctorFP {
inline HOSTDEVICE T operator()(T a, T b) const {
T res = fmod(b, a);
if ((res != 0) && ((a < 0) != (res < 0))) res += a;
return res;
}
};
template <typename DeviceContext, typename T>
void elementwise_mod(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z) {
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<ModFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
ModFunctor<T>(), z);
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<ModFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
ModFunctor<T>(), z);
} else {
ElementwiseComputeEx<InverseModFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, InverseModFunctor<T>(), z);
}
}
template <typename DeviceContext, typename T>
......@@ -54,8 +79,15 @@ void elementwise_mod_fp(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z) {
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<ModFunctorFP<T>, DeviceContext, T>(ctx, x, y, axis,
ModFunctorFP<T>(), z);
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<ModFunctorFP<T>, DeviceContext, T>(
ctx, x, y, axis, ModFunctorFP<T>(), z);
} else {
ElementwiseComputeEx<InverseModFunctorFP<T>, DeviceContext, T>(
ctx, x, y, axis, InverseModFunctorFP<T>(), z);
}
}
template <typename DeviceContext, typename T>
......
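
The new InverseModFunctor/InverseModFunctorFP cover the broadcast case where y has higher rank than x: the operands are swapped so the divisor is a, and the same sign fix yields floored (Python-style) modulo in both directions. A stand-alone integer sketch of the two functors' semantics (function names are mine, mirroring the diff):

#include <cstdio>

// Floored modulo of a % b: the result takes the sign of the divisor b,
// matching ModFunctor in the diff.
int FlooredMod(int a, int b) {
  int res = a % b;
  if (res != 0 && (res < 0) != (b < 0)) res += b;
  return res;
}

// InverseModFunctor swaps the operands: element-wise it computes b % a with
// the same sign fix (the divisor is now a), used when y broadcasts over x.
int InverseFlooredMod(int a, int b) { return FlooredMod(b, a); }

int main() {
  std::printf("%d\n", FlooredMod(-7, 3));         // 2, like Python's -7 % 3
  std::printf("%d\n", InverseFlooredMod(3, -7));  // 2, operands swapped
}
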
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fill_constant_op.h"
......@@ -24,15 +25,20 @@ template <typename T>
struct GaussianGenerator {
T mean_, std_;
unsigned int seed_;
unsigned int offset_ = 0;
__host__ __device__ GaussianGenerator(T mean, T std, int seed)
: mean_(mean), std_(std), seed_(seed) {}
__host__ __device__ GaussianGenerator(T mean, T std, int seed, int offset)
: mean_(mean), std_(std), seed_(seed), offset_(offset) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::normal_distribution<T> dist(mean_, std_);
rng.discard(n);
unsigned int new_n = n + offset_;
rng.discard(new_n);
return dist(rng);
}
};
......@@ -43,9 +49,11 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}
T mean = static_cast<T>(context.Attr<float>("mean"));
T std = static_cast<T>(context.Attr<float>("std"));
......@@ -56,9 +64,23 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
T* data = tensor->mutable_data<T>(context.GetPlace());
int64_t size = tensor->numel();
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed));
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed_flag) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int gen_offset = size * seed_offset.second;
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed_offset.first, gen_offset));
} else {
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed));
}
}
};
......@@ -69,17 +91,33 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}
T mean = static_cast<T>(context.Attr<float>("mean"));
T std = static_cast<T>(context.Attr<float>("std"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
int64_t size = tensor->numel();
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed));
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed_flag) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int gen_offset = size * seed_offset.second;
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed_offset.first,
seed_offset.second));
} else {
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed));
}
}
};
} // namespace operators
......
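
The GaussianGenerator change threads the generator offset into the thrust functor: each element discards index + offset values before drawing, and the per-launch offset grows by size * counter, so launches that share a seed draw from disjoint stream positions. A CPU-only sketch with std::random standing in for thrust (the struct below is my stand-in, not the operator's code):

#include <cstdio>
#include <random>

// Stand-in for the GaussianGenerator functor: element index n plus a
// per-launch offset picks this element's slot in the random stream, so two
// launches with the same seed do not reuse slots.
struct GaussianStandIn {
  float mean, stddev;
  unsigned int seed, offset;
  float operator()(unsigned int n) const {
    std::minstd_rand rng(seed);
    std::normal_distribution<float> dist(mean, stddev);
    rng.discard(n + offset);  // skip to this element's position
    return dist(rng);
  }
};

int main() {
  // First launch draws 6 elements at offset 0; the next launch advances the
  // offset by 6 (gen_offset = size * counter in the diff) before drawing.
  GaussianStandIn launch0{0.f, 1.f, 42u, 0u};
  GaussianStandIn launch1{0.f, 1.f, 42u, 6u};
  std::printf("%f %f\n", launch0(0), launch1(0));
}
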
......@@ -67,7 +67,7 @@ static void Interpolate1DInferShapeCheck(framework::InferShapeContext* ctx) {
scale_tensor[0], 1,
platform::errors::InvalidArgument(
"Scale's shape must be 1, but got shape = %d .", scale_tensor[0]));
// out_w = -1;
out_w = -1;
} else {
auto scale = ctx->Attrs().Get<std::vector<float>>("scale");
if (scale.size() > 0) {
......@@ -159,8 +159,8 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
platform::errors::InvalidArgument(
"Scale's shape must be 2 or 1, but got shape = %d .",
scale_tensor[0]));
// out_h = -1;
// out_w = -1;
out_h = -1;
out_w = -1;
} else {
auto scale = ctx->Attrs().Get<std::vector<float>>("scale");
if (scale.size() > 0) {
......@@ -264,9 +264,9 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
platform::errors::InvalidArgument(
"Scale's shape must be 3 or 1, but got shape = %d .",
scale_tensor[0]));
// out_d = -1;
// out_h = -1;
// out_w = -1;
out_d = -1;
out_h = -1;
out_w = -1;
} else {
auto scale = ctx->Attrs().Get<std::vector<float>>("scale");
if (scale.size() > 0) {
......@@ -633,6 +633,9 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(InterpolateV2GradNoNeedBufferVarsInferer,
} // namespace operators
} // namespace paddle
// interp_v2 supports scale_factor whose input type is list; this operation is
// not compatible with interp_op, so a new one is added in paddle2.0
namespace ops = paddle::operators;
REGISTER_OPERATOR(bilinear_interp_v2, ops::InterpolateV2Op,
ops::InterpolateV2OpMaker,
......
......@@ -836,12 +836,12 @@ static void Interpolate1DCUDAFwd(const framework::ExecutionContext& ctx,
int out_w = ctx.Attr<int>("out_w");
auto list_new_shape_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
float scale_w = -1;
if (list_new_shape_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_shape_tensor);
out_w = new_size[0];
} else {
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
......@@ -887,8 +887,11 @@ static void Interpolate1DCUDAFwd(const framework::ExecutionContext& ctx,
float ratio_w = 0.f;
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1.0) / (out_w - 1.0)
: static_cast<float>(in_w) / out_w;
: static_cast<float>(new_scale_w);
}
int in_cw = c * in_w;
......@@ -924,14 +927,14 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
int out_w = ctx.Attr<int>("out_w");
auto list_new_shape_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
float scale_w = -1;
float scale_h = -1;
if (list_new_shape_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_shape_tensor);
out_h = new_size[0];
out_w = new_size[1];
} else {
float scale_h = -1;
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
......@@ -993,12 +996,18 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
: static_cast<float>(new_scale_w);
}
int in_hw = in_h * in_w;
......@@ -1048,6 +1057,9 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
int out_w = ctx.Attr<int>("out_w");
auto list_new_shape_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
float scale_w = -1;
float scale_d = -1;
float scale_h = -1;
if (list_new_shape_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_shape_tensor);
......@@ -1055,9 +1067,6 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
out_h = new_size[1];
out_w = new_size[2];
} else {
float scale_d = -1;
float scale_h = -1;
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
......@@ -1129,16 +1138,25 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_d > 1) {
float new_scale_d = 0.f;
new_scale_d = (scale_d > 0) ? static_cast<float>(1. / scale_d)
: static_cast<float>(in_d) / out_d;
ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
: static_cast<float>(in_d) / out_d;
: static_cast<float>(new_scale_d);
}
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
: static_cast<float>(new_scale_w);
}
int in_dhw = in_d * in_h * in_w;
......@@ -1230,8 +1248,11 @@ static void Interpolate1DCUDABwd(const framework::ExecutionContext& ctx,
float ratio_w = 0.f;
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
: static_cast<float>(new_scale_w);
}
int in_cw = c * in_w;
int out_cw = c * out_w;
......@@ -1333,12 +1354,18 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
: static_cast<float>(new_scale_w);
}
int in_hw = in_h * in_w;
......@@ -1464,16 +1491,25 @@ static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx,
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_d > 1) {
float new_scale_d = 0.f;
new_scale_d = (scale_d > 0) ? static_cast<float>(1. / scale_d)
: static_cast<float>(in_d) / out_d;
ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
: static_cast<float>(in_d) / out_d;
: static_cast<float>(new_scale_d);
}
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
: static_cast<float>(new_scale_w);
}
int in_dhw = in_d * in_h * in_w;
......
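
The same ratio selection now recurs in every forward and backward kernel above: the explicit scale attribute, when positive, takes priority over the in/out quotient, while align_corners still maps corner to corner. A minimal standalone sketch of that logic (ComputeRatio is a hypothetical helper, not part of the operator code):

#include <cstdio>

static float ComputeRatio(int in, int out, float scale, bool align_corners) {
  if (out <= 1) return 0.f;
  float new_scale = (scale > 0) ? static_cast<float>(1. / scale)
                                : static_cast<float>(in) / out;
  return (align_corners) ? static_cast<float>(in - 1) / (out - 1) : new_scale;
}

int main() {
  // in = 8, out = 16: an explicit scale of 2.0 gives 0.5,
  // while align_corners gives 7 / 15.
  std::printf("%f %f\n", ComputeRatio(8, 16, 2.0f, false),
              ComputeRatio(8, 16, -1.f, true));
}
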
......@@ -783,12 +783,13 @@ static void Interpolate1DCPUFwd(const framework::ExecutionContext& ctx,
int out_w = ctx.Attr<int>("out_w");
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
float scale_w = -1.;
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_w = new_size[0];
} else {
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
......@@ -833,8 +834,11 @@ static void Interpolate1DCPUFwd(const framework::ExecutionContext& ctx,
float ratio_w = 0.f;
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
: static_cast<float>(new_scale_w);
}
if ("linear" == interp_method) {
LinearInterpolation<T>(input, output, ratio_w, in_w, n, c, out_w,
......@@ -856,6 +860,8 @@ static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale_h = -1;
float scale_w = -1;
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
......@@ -864,8 +870,6 @@ static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
out_h = new_size[0];
out_w = new_size[1];
} else {
float scale_h = -1;
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
......@@ -925,12 +929,18 @@ static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
: static_cast<float>(new_scale_w);
}
if ("bilinear" == interp_method) {
......@@ -962,6 +972,10 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale_d = -1;
float scale_h = -1;
float scale_w = -1;
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
......@@ -970,9 +984,6 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
out_h = new_size[1];
out_w = new_size[2];
} else {
float scale_d = -1;
float scale_h = -1;
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
......@@ -1043,16 +1054,25 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_d > 1) {
float new_scale_d = 0.f;
new_scale_d = (scale_d > 0) ? static_cast<float>(1. / scale_d)
: static_cast<float>(in_d) / out_d;
ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
: static_cast<float>(in_d) / out_d;
: static_cast<float>(new_scale_d);
}
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
: static_cast<float>(new_scale_w);
}
if ("trilinear" == interp_method) {
......@@ -1127,8 +1147,11 @@ static void Interpolate1DCPUBwd(const framework::ExecutionContext& ctx,
float ratio_w = 0.f;
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
: static_cast<float>(new_scale_w);
}
if ("linear" == interp_method) {
LinearInterpolationGrad<T>(output_grad, input_grad, ratio_w, in_w, n, c,
......@@ -1216,12 +1239,18 @@ static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx,
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
: static_cast<float>(new_scale_w);
}
if ("bilinear" == interp_method) {
......@@ -1327,16 +1356,25 @@ static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx,
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_d > 1) {
float new_scale_d = 0.f;
new_scale_d = (scale_d > 0) ? static_cast<float>(1. / scale_d)
: static_cast<float>(in_d) / out_d;
ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
: static_cast<float>(in_d) / out_d;
: static_cast<float>(new_scale_d);
}
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
: static_cast<float>(new_scale_w);
}
if ("trilinear" == interp_method) {
......
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/linspace_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
......@@ -19,6 +20,8 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
__global__ void LinspaceKernel(T start, double step, int64_t size, T* out) {
CUDA_KERNEL_LOOP(index, size) {
......@@ -35,15 +38,27 @@ template <typename T>
class CUDALinspaceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* start_t = context.Input<framework::Tensor>("Start");
auto* stop_t = context.Input<framework::Tensor>("Stop");
auto* pre_start = context.Input<framework::Tensor>("Start");
auto* pre_stop = context.Input<framework::Tensor>("Stop");
auto* num_t = context.Input<framework::Tensor>("Num");
auto* out = context.Output<framework::Tensor>("Out");
auto dtype = static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype"));
Tensor start_t;
Tensor stop_t;
auto start_dtype =
framework::OpKernelType(pre_start->type(), context.GetPlace());
auto stop_dtype =
framework::OpKernelType(pre_stop->type(), context.GetPlace());
auto out_dtype = framework::OpKernelType(dtype, context.GetPlace());
framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t);
framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t);
framework::Tensor n;
framework::TensorCopy(*start_t, platform::CPUPlace(), &n);
framework::TensorCopy(start_t, platform::CPUPlace(), &n);
T start = n.data<T>()[0];
framework::TensorCopy(*stop_t, platform::CPUPlace(), &n);
framework::TensorCopy(stop_t, platform::CPUPlace(), &n);
T stop = n.data<T>()[0];
framework::TensorCopy(*num_t, platform::CPUPlace(), &n);
int32_t num = n.data<int32_t>()[0];
......
......@@ -14,20 +14,38 @@ limitations under the License. */
#pragma once
#include <functional>
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class CPULinspaceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
T start = context.Input<framework::Tensor>("Start")->data<T>()[0];
T stop = context.Input<framework::Tensor>("Stop")->data<T>()[0];
auto* pre_start = context.Input<framework::Tensor>("Start");
auto* pre_stop = context.Input<framework::Tensor>("Stop");
int32_t num = context.Input<framework::Tensor>("Num")->data<int32_t>()[0];
auto* out = context.Output<framework::Tensor>("Out");
auto dtype = static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype"));
Tensor start_t;
Tensor stop_t;
auto start_dtype =
framework::OpKernelType(pre_start->type(), context.GetPlace());
auto stop_dtype =
framework::OpKernelType(pre_stop->type(), context.GetPlace());
auto out_dtype = framework::OpKernelType(dtype, context.GetPlace());
framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t);
framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t);
T start = start_t.data<T>()[0];
T stop = stop_t.data<T>()[0];
PADDLE_ENFORCE(num > 0, "The Num input of the linspace op should be larger than 0.");
out->Resize(framework::make_ddim({num}));
......
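
Both the CPU and CUDA kernels now run Start and Stop through TransDataType so the scalars are read in the requested output dtype. The arithmetic itself stays simple; a minimal sketch (Linspace is an illustrative helper, and the op separately enforces num > 0):

#include <cstdio>
#include <vector>

template <typename T>
std::vector<T> Linspace(T start, T stop, int num) {
  std::vector<T> out(num);
  double step =
      (num > 1) ? (static_cast<double>(stop) - start) / (num - 1) : 0.0;
  for (int i = 0; i < num; ++i) out[i] = static_cast<T>(start + step * i);
  return out;
}

int main() {
  for (float v : Linspace<float>(0.f, 1.f, 5)) std::printf("%g ", v);
  // prints: 0 0.25 0.5 0.75 1
}
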
......@@ -65,13 +65,14 @@ class SplitFunctor {
} // namespace operators
} // namespace paddle
#define FOR_ALL_TYPES(macro) \
macro(int); \
macro(float); \
macro(double); \
macro(bool); \
macro(int64_t); \
macro(int16_t); \
macro(uint8_t); \
macro(int8_t); \
macro(::paddle::platform::float16)
#define FOR_ALL_TYPES(macro) \
macro(int); \
macro(float); \
macro(double); \
macro(bool); \
macro(int64_t); \
macro(int16_t); \
macro(uint8_t); \
macro(int8_t); \
macro(::paddle::platform::float16); \
macro(::paddle::platform::bfloat16)
......@@ -34,6 +34,7 @@ namespace math {
using float16 = paddle::platform::float16;
template struct SetConstant<platform::CPUDeviceContext, platform::float16>;
template struct SetConstant<platform::CPUDeviceContext, platform::bfloat16>;
template struct SetConstant<platform::CPUDeviceContext, float>;
template struct SetConstant<platform::CPUDeviceContext, double>;
template struct SetConstant<platform::CPUDeviceContext, int>;
......@@ -41,16 +42,18 @@ template struct SetConstant<platform::CPUDeviceContext, int64_t>;
template struct SetConstant<platform::CPUDeviceContext, bool>;
template struct SetConstant<platform::CPUDeviceContext, uint8_t>;
#define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUDeviceContext, platform::float16, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, float, RANK>; \
template struct Transpose<platform::CPUDeviceContext, double, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, bool, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>; \
#define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUDeviceContext, platform::float16, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, platform::bfloat16, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, float, RANK>; \
template struct Transpose<platform::CPUDeviceContext, double, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, bool, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int8_t, RANK>;
DEFINE_CPU_TRANS(1);
......
......@@ -1055,7 +1055,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
astream.wait();
filter_grad->set_layout(DataLayout::kMKLDNN);
filter_grad->set_format(GetMKLDNNFormat(*diff_weights_memory_p));
// In oneDNN, groups in convolution are treated as a separate dimension,
// which is not the case in PaddlePaddle.
auto filter_fmt = GetMKLDNNFormat(*diff_weights_memory_p);
filter_grad->set_format(platform::MKLDNNFormatForSize(
g > 1 ? weights_tz.size() - 1 : weights_tz.size(), filter_fmt));
}
if (input_grad) {
auto weights_memory_p = handler.AcquireWeightsMemoryFromDataPrimitive(
......
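
The format fix above hinges on how the two libraries shape grouped convolution weights; an illustrative (non-Paddle) helper showing the dimension counts involved:

#include <vector>

std::vector<int> OneDNNWeightDims(int g, int oc, int ic, int kh, int kw) {
  if (g > 1) return {g, oc / g, ic / g, kh, kw};  // 5 dims in oneDNN
  return {oc, ic, kh, kw};                        // 4 dims, as in Paddle
}

int main() { return OneDNNWeightDims(2, 8, 4, 3, 3).size() == 5 ? 0 : 1; }
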
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/uniform_random_op.h"
......@@ -49,15 +50,23 @@ class GPURandintKernel : public framework::OpKernel<T> {
int64_t size = out->numel();
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::uniform_int_distribution<> dist(context.Attr<int>("low"),
context.Attr<int>("high") - 1);
for (int64_t i = 0; i < size; ++i) data[i] = dist(engine);
auto engine = framework::GetCPURandomEngine(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(*engine);
}
if (platform::is_gpu_place(context.GetPlace())) {
// Copy tensor to out
......
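
GetCPURandomEngine hands back a framework-owned engine, so the kernel no longer seeds its own. A rough sketch of the assumed contract (MakeEngine is hypothetical and omits the process-wide engine sharing the real framework does):

#include <cstdint>
#include <memory>
#include <random>

std::shared_ptr<std::mt19937_64> MakeEngine(uint64_t seed) {
  if (seed == 0) seed = std::random_device{}();  // 0 means non-deterministic
  return std::make_shared<std::mt19937_64>(seed);
}

int main() {
  auto engine = MakeEngine(0);
  std::uniform_int_distribution<int> dist(0, 9);  // mirrors low / high - 1
  return dist(*engine) >= 0 ? 0 : 1;
}
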
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h"
REGISTER_OP_CUDA_KERNEL(logsumexp,
......@@ -20,8 +19,3 @@ REGISTER_OP_CUDA_KERNEL(logsumexp,
float, ops::LogsumexpFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
double, ops::LogsumexpFunctor>);
REGISTER_OP_CUDA_KERNEL(
logsumexp_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
float, ops::LogsumexpGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, double,
ops::LogsumexpGradFunctor>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The .part file is used to speed up nvcc compilation.
#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h"
REGISTER_OP_CUDA_KERNEL(
logsumexp_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
float, ops::LogsumexpGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, double,
ops::LogsumexpGradFunctor>);
......@@ -53,7 +53,7 @@ REGISTER_OPERATOR(
size, ops::SizeOp, ops::SizeOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(size, ops::SizeKernel<int>, ops::SizeKernel<int32_t>,
REGISTER_OP_CPU_KERNEL(size, ops::SizeKernel<int>, ops::SizeKernel<int64_t>,
ops::SizeKernel<paddle::platform::float16>,
ops::SizeKernel<float>, ops::SizeKernel<double>,
ops::SizeKernel<bool>);
......@@ -16,7 +16,7 @@ limitations under the License. */
REGISTER_OP_CUDA_KERNEL(
size, paddle::operators::SizeKernel<int>,
paddle::operators::SizeKernel<int32_t>,
paddle::operators::SizeKernel<int64_t>,
paddle::operators::SizeKernel<paddle::platform::float16>,
paddle::operators::SizeKernel<float>, paddle::operators::SizeKernel<bool>,
paddle::operators::SizeKernel<double>);
......@@ -26,8 +26,18 @@ class SizeKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_t = ctx.Input<Tensor>("Input");
auto* out_t = ctx.Output<Tensor>("Out");
auto out_data = out_t->mutable_data<int64_t>(platform::CPUPlace());
out_data[0] = in_t->numel();
auto place = ctx.GetPlace();
auto out_data = out_t->mutable_data<int64_t>(place);
auto cpu_place = platform::CPUPlace();
if (place == cpu_place) {
out_data[0] = in_t->numel();
} else {
Tensor cpu_tensor;
auto cpu_data =
cpu_tensor.mutable_data<int64_t>(out_t->dims(), cpu_place);
cpu_data[0] = in_t->numel();
TensorCopy(cpu_tensor, place, out_t);
}
}
};
} // namespace operators
......
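
The registration change above is easy to miss: on Paddle's usual targets int32_t is a typedef of int, so the old CPU list instantiated SizeKernel<int> twice while int64_t was missing. A two-line check of that assumption:

#include <cstdint>
#include <type_traits>

static_assert(std::is_same<int, std::int32_t>::value,
              "int32_t aliases int here, so <int> already covers it");

int main() { return 0; }
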
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>
#include <limits>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
......@@ -46,6 +47,37 @@ struct TruncatedNormal {
}
};
template <typename T>
struct TruncatedNormalOffset {
T mean, std;
T a_normal_cdf;
T b_normal_cdf;
unsigned int seed;
T numeric_min;
int offset_;
__host__ __device__ TruncatedNormalOffset(T mean, T std, T numeric_min,
int seed, int offset)
: mean(mean),
std(std),
seed(seed),
numeric_min(numeric_min),
offset_(offset) {
a_normal_cdf = (1.0 + erff(-2.0 / sqrtf(2.0))) / 2.0;
b_normal_cdf = (1.0 + erff(2.0 / sqrtf(2.0))) / 2.0;
}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed);
thrust::uniform_real_distribution<T> dist(numeric_min, 1);
rng.discard(n + offset_);
T value = dist(rng);
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean;
}
};
template <typename T>
class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
public:
......@@ -54,14 +86,31 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}
T mean = static_cast<T>(context.Attr<float>("mean"));
T std = static_cast<T>(context.Attr<float>("std"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
int64_t size = tensor->numel();
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed_flag) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int gen_offset = size * seed_offset.second;
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
TruncatedNormalOffset<T>(mean, std, std::numeric_limits<T>::min(),
seed_offset.first, seed_offset.second));
}
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
......
......@@ -51,6 +51,39 @@ struct UniformGenerator {
}
};
template <typename T>
struct UniformGeneratorOffset {
T min_, max_;
unsigned int seed_;
T diag_val_;
unsigned int diag_num_;
unsigned int diag_step_;
int offset_;
__host__ __device__ UniformGeneratorOffset(T min, T max, int seed,
int diag_num, int diag_step,
T diag_val, int offset)
: min_(min),
max_(max),
seed_(seed),
diag_num_(diag_num),
diag_step_(diag_step),
diag_val_(diag_val),
offset_(offset) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n + offset_);
T out = dist(rng);
unsigned int remainder = n % (diag_step_ + 1);
if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
out = diag_val_;
}
return out;
}
};
// It seems that Eigen::Tensor::random on GPU will SEGFAULT.
// Use std::random and thrust::random (Thrust is a standard library shipped
// with CUDA) to implement uniform random.
......@@ -89,10 +122,11 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
}
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}
T min = static_cast<T>(context.Attr<float>("min"));
......@@ -104,10 +138,23 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
T diag_val = static_cast<T>(context.Attr<float>("diag_val"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
int64_t size = tensor->numel();
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val));
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed_flag) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int gen_offset = size * seed_offset.second;
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
UniformGeneratorOffset<T>(min, max, seed_offset.first, diag_num,
diag_step, diag_val, gen_offset));
} else {
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val));
}
}
};
......
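
Both offset-based functors rely on discard() to give every launch its own slice of a single deterministic stream, instead of reseeding per launch. A minimal sketch of that idea (NthUniform is an illustrative helper):

#include <cstdio>
#include <random>

float NthUniform(unsigned seed, unsigned long long n,
                 unsigned long long offset) {
  std::minstd_rand rng(seed);
  std::uniform_real_distribution<float> dist(0.f, 1.f);
  rng.discard(n + offset);  // jump to this element of this launch's slice
  return dist(rng);
}

int main() {
  // Element 3 of launch 0 vs. element 3 of a launch offset by 100 draws.
  std::printf("%f %f\n", NthUniform(42, 3, 0), NthUniform(42, 3, 100));
}
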
......@@ -136,6 +136,8 @@ cc_test(profiler_test SRCS profiler_test.cc DEPS profiler)
nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor)
cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor)
cc_test(bfloat16_test SRCS bfloat16_test.cc DEPS lod_tensor)
nv_test(test_limit_gpu_memory SRCS test_limit_gpu_memory.cu DEPS gpu_info flags)
nv_library(cuda_device_guard SRCS cuda_device_guard.cc DEPS gpu_info)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <limits>
#if !defined(_WIN32)
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
#else
#define PADDLE_ALIGN(x) __declspec(align(x))
#endif
#include <cstring>
#include "paddle/fluid/platform/hostdevice.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace platform {
struct PADDLE_ALIGN(2) bfloat16 {
public:
uint16_t x;
bfloat16() = default;
bfloat16(const bfloat16& o) = default;
bfloat16& operator=(const bfloat16& o) = default;
bfloat16(bfloat16&& o) = default;
bfloat16& operator=(bfloat16&& o) = default;
~bfloat16() = default;
HOSTDEVICE inline explicit bfloat16(float val) {
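// Note: truncates the mantissa (rounds toward zero) rather than rounding
// to nearest even, and assumes a little-endian layout, since it takes the
// high two bytes of the float.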
std::memcpy(&x, reinterpret_cast<char*>(&val) + 2, 2);
}
template <class T>
HOSTDEVICE inline explicit bfloat16(const T& val)
: x(bfloat16(static_cast<float>(val)).x) {}
HOSTDEVICE inline bfloat16& operator=(bool b) {
x = b ? 0x3f80 : 0;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(int8_t val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(uint8_t val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(int16_t val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(uint16_t val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(int32_t val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(uint32_t val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(int64_t val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(uint64_t val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(float val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline bfloat16& operator=(double val) {
x = bfloat16(val).x;
return *this;
}
HOSTDEVICE inline explicit operator float() const {
float val = 0.f;
uint16_t temp = x;
memcpy(reinterpret_cast<char*>(&val) + 2, reinterpret_cast<char*>(&temp),
2);
return val;
}
HOSTDEVICE inline explicit operator bool() const { return (x & 0x7fff) != 0; }
HOSTDEVICE inline explicit operator int8_t() const {
return static_cast<int8_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator uint8_t() const {
return static_cast<uint8_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator int16_t() const {
return static_cast<int16_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator uint16_t() const {
return static_cast<uint16_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator int32_t() const {
return static_cast<int32_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator uint32_t() const {
return static_cast<uint32_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator int64_t() const {
return static_cast<int64_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator uint64_t() const {
return static_cast<uint64_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator double() const {
return static_cast<double>(static_cast<float>(*this));
}
};
HOSTDEVICE inline bfloat16 operator+(const bfloat16& a, const bfloat16& b) {
return bfloat16(static_cast<float>(a) + static_cast<float>(b));
}
HOSTDEVICE inline bfloat16 operator-(const bfloat16& a, const bfloat16& b) {
return bfloat16(static_cast<float>(a) - static_cast<float>(b));
}
HOSTDEVICE inline bfloat16 operator*(const bfloat16& a, const bfloat16& b) {
return bfloat16(static_cast<float>(a) * static_cast<float>(b));
}
HOSTDEVICE inline bfloat16 operator/(const bfloat16& a, const bfloat16& b) {
return bfloat16(static_cast<float>(a) / static_cast<float>(b));
}
HOSTDEVICE inline bfloat16 operator-(const bfloat16& a) {
bfloat16 res;
res.x = a.x ^ 0x8000;
return res;
}
HOSTDEVICE inline bfloat16& operator+=(bfloat16& a, // NOLINT
const bfloat16& b) {
a = bfloat16(static_cast<float>(a) + static_cast<float>(b));
return a;
}
HOSTDEVICE inline bfloat16& operator-=(bfloat16& a, // NOLINT
const bfloat16& b) {
a = bfloat16(static_cast<float>(a) - static_cast<float>(b));
return a;
}
HOSTDEVICE inline bfloat16& operator*=(bfloat16& a, // NOLINT
const bfloat16& b) {
a = bfloat16(static_cast<float>(a) * static_cast<float>(b));
return a;
}
HOSTDEVICE inline bfloat16& operator/=(bfloat16& a, // NOLINT
const bfloat16& b) {
a = bfloat16(static_cast<float>(a) / static_cast<float>(b));
return a;
}
HOSTDEVICE inline bfloat16 raw_uint16_to_bfloat16(uint16_t a) {
bfloat16 res;
res.x = a;
return res;
}
HOSTDEVICE inline bool operator==(const bfloat16& a, const bfloat16& b) {
return static_cast<float>(a) == static_cast<float>(b);
}
HOSTDEVICE inline bool operator!=(const bfloat16& a, const bfloat16& b) {
return static_cast<float>(a) != static_cast<float>(b);
}
HOSTDEVICE inline bool operator<(const bfloat16& a, const bfloat16& b) {
return static_cast<float>(a) < static_cast<float>(b);
}
HOSTDEVICE inline bool operator<=(const bfloat16& a, const bfloat16& b) {
return static_cast<float>(a) <= static_cast<float>(b);
}
HOSTDEVICE inline bool operator>(const bfloat16& a, const bfloat16& b) {
return static_cast<float>(a) > static_cast<float>(b);
}
HOSTDEVICE inline bool operator>=(const bfloat16& a, const bfloat16& b) {
return static_cast<float>(a) >= static_cast<float>(b);
}
HOSTDEVICE inline bool(isnan)(const bfloat16& a) {
return (a.x & 0x7FFF) > 0x7F80;
}
HOSTDEVICE inline bool(isinf)(const bfloat16& a) {
return (a.x & 0x7F80) == 0x7F80;
}
HOSTDEVICE inline bool(isfinite)(const bfloat16& a) {
return !((isnan)(a)) && !((isinf)(a));
}
inline std::ostream& operator<<(std::ostream& os, const bfloat16& a) {
os << a.x;
return os;
}
} // namespace platform
} // namespace paddle
namespace std {
template <>
struct is_pod<paddle::platform::bfloat16> {
static const bool value =
is_trivial<paddle::platform::bfloat16>::value &&
is_standard_layout<paddle::platform::bfloat16>::value;
};
template <>
struct is_floating_point<paddle::platform::bfloat16>
: std::integral_constant<
bool, std::is_same<paddle::platform::bfloat16,
typename std::remove_cv<
paddle::platform::bfloat16>::type>::value> {};
template <>
struct is_signed<paddle::platform::bfloat16> {
static const bool value = true;
};
template <>
struct is_unsigned<paddle::platform::bfloat16> {
static const bool value = false;
};
inline bool isnan(const paddle::platform::bfloat16& a) {
return paddle::platform::isnan(a);
}
inline bool isinf(const paddle::platform::bfloat16& a) {
return paddle::platform::isinf(a);
}
template <>
struct numeric_limits<paddle::platform::bfloat16> {
static const bool is_specialized = true;
static const bool is_signed = true;
static const bool is_integer = false;
static const bool is_exact = false;
static const bool has_infinity = true;
static const bool has_quiet_NaN = true;
static const bool has_signaling_NaN = true;
static const float_denorm_style has_denorm = denorm_present;
static const bool has_denorm_loss = false;
static const std::float_round_style round_style = std::round_to_nearest;
static const bool is_iec559 = false;
static const bool is_bounded = false;
static const bool is_modulo = false;
static const int digits = 8;
static const int digits10 = 2;
static const int max_digits10 = 4;  // ceil(digits * log10(2)) + 1
static const int radix = 2;
static const int min_exponent = -125;
static const int min_exponent10 = -37;
static const int max_exponent = 128;
static const int max_exponent10 = 38;
static const bool traps = true;
static const bool tinyness_before = false;
static paddle::platform::bfloat16(min)() {
// smallest positive normal value: exponent 1, zero mantissa
return paddle::platform::raw_uint16_to_bfloat16(0x0080);
}
static paddle::platform::bfloat16 lowest() {
return paddle::platform::raw_uint16_to_bfloat16(0xff7f);
}
static paddle::platform::bfloat16(max)() {
return paddle::platform::raw_uint16_to_bfloat16(0x7f7f);
}
static paddle::platform::bfloat16 epsilon() {
return paddle::platform::raw_uint16_to_bfloat16(0x3400);
}
static paddle::platform::bfloat16 round_error() {
return paddle::platform::bfloat16(0.5);
}
static paddle::platform::bfloat16 infinity() {
return paddle::platform::raw_uint16_to_bfloat16(0x7f80);
}
static paddle::platform::bfloat16 quiet_NaN() {
return paddle::platform::raw_uint16_to_bfloat16(0xffc1);
}
static paddle::platform::bfloat16 signaling_NaN() {
return paddle::platform::raw_uint16_to_bfloat16(0xff81);
}
static paddle::platform::bfloat16 denorm_min() {
return paddle::platform::raw_uint16_to_bfloat16(0x0001);
}
};
} // namespace std
namespace Eigen {
using bfloat16 = paddle::platform::bfloat16;
template <>
struct NumTraits<bfloat16> : GenericNumTraits<bfloat16> {
enum {
IsSigned = true,
IsInteger = false,
IsComplex = false,
RequireInitialization = false
};
HOSTDEVICE static inline bfloat16 epsilon() {
return paddle::platform::raw_uint16_to_bfloat16(0x3400);
}
HOSTDEVICE static inline bfloat16 dummy_precision() {
return bfloat16(1e-5f);
}
HOSTDEVICE static inline bfloat16 highest() {
return paddle::platform::raw_uint16_to_bfloat16(0x7f7f);
}
HOSTDEVICE static inline bfloat16 lowest() {
return paddle::platform::raw_uint16_to_bfloat16(0xff7f);
}
HOSTDEVICE static inline bfloat16 infinity() {
return paddle::platform::raw_uint16_to_bfloat16(0x7f80);
}
HOSTDEVICE static inline bfloat16 quiet_NaN() {
return paddle::platform::raw_uint16_to_bfloat16(0xffc1);
}
};
namespace numext {
template <>
HOSTDEVICE inline bool(isnan)(const bfloat16& a) {
return (paddle::platform::isnan)(a);
}
template <>
HOSTDEVICE inline bool(isinf)(const bfloat16& a) {
return (paddle::platform::isinf)(a);
}
template <>
HOSTDEVICE inline bool(isfinite)(const bfloat16& a) {
return (paddle::platform::isfinite)(a);
}
template <>
HOSTDEVICE inline bfloat16 exp(const bfloat16& a) {
return bfloat16(::expf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline bfloat16 erf(const bfloat16& a) {
return bfloat16(::erff(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline bfloat16 log(const bfloat16& a) {
return bfloat16(::logf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline bfloat16 tanh(const bfloat16& a) {
return bfloat16(::tanhf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline bfloat16 sqrt(const bfloat16& a) {
return bfloat16(::sqrtf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline bfloat16 ceil(const bfloat16& a) {
return bfloat16(::ceilf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline bfloat16 floor(const bfloat16& a) {
return bfloat16(::floorf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline bfloat16 round(const bfloat16& a) {
return bfloat16(::roundf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline bfloat16 pow(const bfloat16& a, const bfloat16& b) {
return bfloat16(::powf(static_cast<float>(a), static_cast<float>(b)));
}
template <>
HOSTDEVICE inline bfloat16 abs(const bfloat16& a) {
return bfloat16(::fabs(static_cast<float>(a)));
}
} // namespace numext
} // namespace Eigen
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/bfloat16.h"
#include <vector>
#define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/init.h"
namespace paddle {
namespace platform {
using bfloat16 = paddle::platform::bfloat16;
TEST(bfloat16, conversion_cpu) {
// Conversion from float
EXPECT_EQ(bfloat16(1.0f).x, 0x3f80);
EXPECT_EQ(bfloat16(0.5f).x, 0x3f00);
EXPECT_EQ(bfloat16(0.33333f).x, 0x3eaa);
EXPECT_EQ(bfloat16(0.0f).x, 0x0000);
EXPECT_EQ(bfloat16(-0.0f).x, 0x8000);
EXPECT_EQ(bfloat16(65504.0f).x, 0x477f);
EXPECT_EQ(bfloat16(65536.0f).x, 0x4780);
// Conversion from double
EXPECT_EQ(bfloat16(1.0).x, 0x3f80);
EXPECT_EQ(bfloat16(0.5).x, 0x3f00);
EXPECT_EQ(bfloat16(0.33333).x, 0x3eaa);
EXPECT_EQ(bfloat16(0.0).x, 0x0000);
EXPECT_EQ(bfloat16(-0.0).x, 0x8000);
EXPECT_EQ(bfloat16(65504.0).x, 0x477f);
EXPECT_EQ(bfloat16(65536.0).x, 0x4780);
// Conversion from int
EXPECT_EQ(bfloat16(-1).x, 0xbf80);
EXPECT_EQ(bfloat16(0).x, 0x0000);
EXPECT_EQ(bfloat16(1).x, 0x3f80);
EXPECT_EQ(bfloat16(2).x, 0x4000);
EXPECT_EQ(bfloat16(3).x, 0x4040);
// Conversion from bool
EXPECT_EQ(bfloat16(true).x, 0x3f80);
EXPECT_EQ(bfloat16(false).x, 0x0000);
// Assignment operator
bfloat16 v_assign;
v_assign = bfloat16(0.f);
EXPECT_EQ(v_assign.x, 0x0000);
v_assign = 0.5f;
EXPECT_EQ(v_assign.x, 0x3f00);
v_assign = 0.33333;
EXPECT_EQ(v_assign.x, 0x3eaa);
v_assign = -1;
EXPECT_EQ(v_assign.x, 0xbf80);
// Conversion operator
EXPECT_EQ(static_cast<float>(bfloat16(0.5f)), 0.5f);
EXPECT_NEAR(static_cast<double>(bfloat16(0.33333)), 0.33333, 0.01);
EXPECT_EQ(static_cast<int>(bfloat16(-1)), -1);
EXPECT_EQ(static_cast<bool>(bfloat16(true)), true);
}
TEST(bfloat16, arithmetic_cpu) {
EXPECT_NEAR(static_cast<float>(bfloat16(1) + bfloat16(1)), 2, 0.001);
EXPECT_EQ(static_cast<float>(bfloat16(5) + bfloat16(-5)), 0);
EXPECT_NEAR(static_cast<float>(bfloat16(0.33333f) + bfloat16(0.66667f)), 1.0f,
0.01);
EXPECT_EQ(static_cast<float>(bfloat16(3) - bfloat16(5)), -2);
EXPECT_NEAR(static_cast<float>(bfloat16(0.66667f) - bfloat16(0.33333f)),
0.33334f, 0.01);
EXPECT_NEAR(static_cast<float>(bfloat16(3.3f) * bfloat16(2.0f)), 6.6f, 0.01);
EXPECT_NEAR(static_cast<float>(bfloat16(-2.1f) * bfloat16(-3.0f)), 6.3f, 0.1);
EXPECT_NEAR(static_cast<float>(bfloat16(2.0f) / bfloat16(3.0f)), 0.66667f,
0.01);
EXPECT_EQ(static_cast<float>(bfloat16(1.0f) / bfloat16(2.0f)), 0.5f);
EXPECT_EQ(static_cast<float>(-bfloat16(512.0f)), -512.0f);
EXPECT_EQ(static_cast<float>(-bfloat16(-512.0f)), 512.0f);
}
TEST(bfloat16, comparison_cpu) {
EXPECT_TRUE(bfloat16(1.0f) == bfloat16(1.0f));
EXPECT_FALSE(bfloat16(-1.0f) == bfloat16(-0.5f));
EXPECT_TRUE(bfloat16(1.0f) != bfloat16(0.5f));
EXPECT_FALSE(bfloat16(-1.0f) != bfloat16(-1.0f));
EXPECT_TRUE(bfloat16(1.0f) < bfloat16(2.0f));
EXPECT_FALSE(bfloat16(-1.0f) < bfloat16(-1.0f));
EXPECT_TRUE(bfloat16(1.0f) <= bfloat16(1.0f));
EXPECT_TRUE(bfloat16(2.0f) > bfloat16(1.0f));
EXPECT_FALSE(bfloat16(-2.0f) > bfloat16(-2.0f));
EXPECT_TRUE(bfloat16(2.0f) >= bfloat16(2.0f));
}
TEST(bfloat16, lod_tensor_cpu) {
framework::LoDTensor lod_tensor;
std::vector<bfloat16> input_data = {bfloat16(1.0f), bfloat16(0.5f),
bfloat16(0.33333f), bfloat16(0.0f)};
EXPECT_EQ(input_data[0].x, 0x3f80);
EXPECT_EQ(input_data[1].x, 0x3f00);
EXPECT_EQ(input_data[2].x, 0x3eaa);
EXPECT_EQ(input_data[3].x, 0x0000);
lod_tensor.Resize({4, 1});
lod_tensor.set_lod(framework::LoD({{0, 2, 4}}));
bfloat16* data_ptr = lod_tensor.mutable_data<bfloat16>(CPUPlace());
EXPECT_NE(data_ptr, nullptr);
EXPECT_EQ(input_data.size(), static_cast<size_t>(lod_tensor.numel()));
for (size_t i = 0; i < input_data.size(); ++i) {
data_ptr[i] = input_data[i];
EXPECT_EQ(data_ptr[i].x, input_data[i].x);
}
}
TEST(bfloat16, floating) {
// compile time assert.
PADDLE_ENFORCE_EQ(
std::is_floating_point<bfloat16>::value, true,
platform::errors::Fatal("std::is_floating_point with bfloat16 data type "
"should be equal to true but it is not"));
}
TEST(bfloat16, print) {
bfloat16 a = bfloat16(1.0f);
std::cout << a << std::endl;
}
// CPU test
TEST(bfloat16, isinf) {
bfloat16 a;
a.x = 0x7f80;
bfloat16 b = bfloat16(INFINITY);
bfloat16 c = static_cast<bfloat16>(INFINITY);
EXPECT_EQ(std::isinf(a), true);
EXPECT_EQ(std::isinf(b), true);
EXPECT_EQ(std::isinf(c), true);
}
TEST(bfloat16, isnan) {
bfloat16 a;
a.x = 0x7fff;
bfloat16 b = bfloat16(NAN);
bfloat16 c = static_cast<bfloat16>(NAN);
EXPECT_EQ(std::isnan(a), true);
EXPECT_EQ(std::isnan(b), true);
EXPECT_EQ(std::isnan(c), true);
}
} // namespace platform
} // namespace paddle
......@@ -273,11 +273,116 @@ class ScopedTensorDescriptor {
groups);
}
inline cudnnTensorDescriptor_t descriptor(const cudnnDataType_t cudnn_type,
const std::vector<int>& dim,
const std::vector<int>& stride) {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetTensorNdDescriptor(
desc_, cudnn_type, dim.size(), dim.data(), stride.data()));
return desc_;
}
template <typename T>
inline cudnnTensorDescriptor_t descriptor(const std::vector<int>& dim,
const std::vector<int>& stride) {
return descriptor(CudnnDataType<T>::type, dim, stride);
}
private:
cudnnTensorDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedTensorDescriptor);
};
class ScopedRNNTensorDescriptor {
public:
ScopedRNNTensorDescriptor() {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreateRNNDataDescriptor(&desc_));
}
~ScopedRNNTensorDescriptor() PADDLE_MAY_THROW {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroyRNNDataDescriptor(desc_));
}
inline cudnnRNNDataDescriptor_t descriptor(
const cudnnDataType_t cudnn_type, int max_seq_length, int batch_size,
int input_size, bool time_major, const std::vector<int>& seq_length) {
static float padding_fill = 0.0f;
cudnnRNNDataLayout_t layout;
if (time_major) {
layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;
} else {
layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
}
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetRNNDataDescriptor(
desc_, cudnn_type, layout, max_seq_length, batch_size, input_size,
seq_length.data(), static_cast<void*>(&padding_fill)));
return desc_;
}
template <typename T>
inline cudnnRNNDataDescriptor_t descriptor(
int max_length, int batch_size, int input_size, bool time_major,
const std::vector<int>& seq_length) {
return descriptor(CudnnDataType<T>::type, max_length, batch_size,
input_size, time_major, seq_length);
}
private:
cudnnRNNDataDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedRNNTensorDescriptor);
};
class ScopedDropoutDescriptor {
public:
ScopedDropoutDescriptor() {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreateDropoutDescriptor(&desc_));
}
~ScopedDropoutDescriptor() PADDLE_MAY_THROW {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroyDropoutDescriptor(desc_));
}
inline cudnnDropoutDescriptor_t descriptor(const cudnnHandle_t& handle,
const platform::Place& place,
bool initialized,
float dropout_prob_,
framework::Tensor* dropout_state_,
int seed, size_t state_size) {
auto* dropout_state_data = dropout_state_->data<uint8_t>();
if (!initialized) {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetDropoutDescriptor(
desc_, handle, dropout_prob_, dropout_state_data, state_size, seed));
} else {
auto dropout_state_dims = dropout_state_->dims();
state_size = dropout_state_dims[0];
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnRestoreDropoutDescriptor(
desc_, handle, dropout_prob_, dropout_state_data, state_size, 0));
}
return desc_;
}
private:
cudnnDropoutDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedDropoutDescriptor);
};
class ScopedRNNDescriptor {
public:
ScopedRNNDescriptor() {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreateRNNDescriptor(&desc_));
}
~ScopedRNNDescriptor() PADDLE_MAY_THROW {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroyRNNDescriptor(desc_));
}
inline cudnnRNNDescriptor_t descriptor() { return desc_; }
private:
cudnnRNNDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedRNNDescriptor);
};
class ScopedFilterDescriptor {
public:
ScopedFilterDescriptor() {
......@@ -319,6 +424,167 @@ class ScopedFilterDescriptor {
DISABLE_COPY_AND_ASSIGN(ScopedFilterDescriptor);
};
class ScopedRNNBase {
public:
ScopedRNNBase(int seq_length, int batch_size, int input_size, int hidden_size,
int num_layers, float dropout_prob, int seed, int weight_numel,
bool initialized, bool is_bidirec)
: seq_length_(seq_length),
batch_size_(batch_size),
input_size_(input_size),
hidden_size_(hidden_size),
num_layers_(num_layers),
dropout_prob_(dropout_prob),
seed_(seed),
weight_numel_(weight_numel),
initialized_(initialized),
is_bidirec_(is_bidirec) {}
template <typename T>
void Create(const cudnnHandle_t& handle, const platform::Place& place,
std::vector<int> sequence_length, size_t* workspace_size,
size_t* reserve_size, framework::Tensor* dropout_state) {
int numDirections = is_bidirec_ ? 2 : 1;
cudnnDataType_t cudnn_type = platform::CudnnDataType<T>::type;
// ------------------- cudnn x, y descriptors ---------------------
std::vector<int> dims_x = {batch_size_, input_size_, 1};
std::vector<int> strides_x = {input_size_, 1, 1};
std::vector<int> dims_y = {batch_size_, hidden_size_ * numDirections, 1};
std::vector<int> strides_y = {hidden_size_ * numDirections, 1, 1};
for (int i = 0; i < seq_length_; ++i) {
x_desc_.emplace_back(x_d.descriptor<T>(dims_x, strides_x));
y_desc_.emplace_back(y_d.descriptor<T>(dims_y, strides_y));
}
if (!sequence_length.empty()) {
x_seq_desc_ = x_seq_d.descriptor<T>(seq_length_, batch_size_, input_size_,
true, sequence_length);
y_seq_desc_ = y_seq_d.descriptor<T>(seq_length_, batch_size_,
hidden_size_ * numDirections, true,
sequence_length);
}
// ------------------- cudnn hx, hy, cx, cy descriptors----------
std::vector<int> dims_hx = {num_layers_ * numDirections, batch_size_,
hidden_size_};
std::vector<int> strides_hx = {hidden_size_ * batch_size_, hidden_size_, 1};
hx_desc_ = hx_d.descriptor<T>(dims_hx, strides_hx);
cx_desc_ = cx_d.descriptor<T>(dims_hx, strides_hx);
hy_desc_ = hy_d.descriptor<T>(dims_hx, strides_hx);
cy_desc_ = cy_d.descriptor<T>(dims_hx, strides_hx);
// ------------------- cudnn dropout descriptors ---------------------
size_t state_size;
if (!initialized_) {
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cudnnDropoutGetStatesSize(handle, &state_size));
dropout_state->mutable_data<uint8_t>({static_cast<int64_t>(state_size)},
place);
}
dropout_desc_ =
dropout_d.descriptor(handle, place, initialized_, dropout_prob_,
dropout_state, seed_, state_size);
// ------------------- cudnn rnn descriptors ---------------------
rnn_desc_ = rnn_d.descriptor();
#if CUDNN_VERSION >= 6000
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor_v6(
handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
CUDNN_LINEAR_INPUT,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
CUDNN_RNN_ALGO_STANDARD, cudnn_type));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor(
rnn_desc_, hidden_size_, num_layers_, dropout_desc_, CUDNN_LINEAR_INPUT,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
cudnn_type));
#endif
if (!sequence_length.empty()) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNPaddingMode(
rnn_desc_, CUDNN_RNN_PADDED_IO_ENABLED));
}
// ------------------- cudnn weights_size ---------------------
size_t weights_size_;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNParamsSize(
handle, rnn_desc_, x_desc_[0], &weights_size_, cudnn_type));
PADDLE_ENFORCE_EQ(
weights_size_, sizeof(T) * weight_numel_,
platform::errors::InvalidArgument(
"The cudnn lstm and setting weight size should be same."));
// ------------------- cudnn weight descriptors ---------------------
platform::DataLayout layout = platform::DataLayout::kNCHW;
int dim_tmp = weights_size_ / sizeof(T);
std::vector<int> dim_w = {dim_tmp, 1, 1};
w_desc_ = w_d.descriptor<T>(layout, dim_w);
// ------------------- cudnn workspace, reserve size ---------------------
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNWorkspaceSize(
handle, rnn_desc_, seq_length_, x_desc_.data(), workspace_size));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetRNNTrainingReserveSize(
handle, rnn_desc_, seq_length_, x_desc_.data(), reserve_size));
}
cudnnTensorDescriptor_t* x_desc() { return x_desc_.data(); }
cudnnTensorDescriptor_t* y_desc() { return y_desc_.data(); }
cudnnRNNDataDescriptor_t x_seq_desc() { return x_seq_desc_; }
cudnnRNNDataDescriptor_t y_seq_desc() { return y_seq_desc_; }
cudnnTensorDescriptor_t hx_desc() { return hx_desc_; }
cudnnTensorDescriptor_t cx_desc() { return cx_desc_; }
cudnnTensorDescriptor_t hy_desc() { return hy_desc_; }
cudnnTensorDescriptor_t cy_desc() { return cy_desc_; }
cudnnRNNDescriptor_t rnn_desc() { return rnn_desc_; }
cudnnDropoutDescriptor_t dropout_desc() { return dropout_desc_; }
cudnnFilterDescriptor_t w_desc() { return w_desc_; }
private:
int seq_length_;
int batch_size_;
int input_size_;
int hidden_size_;
int num_layers_;
float dropout_prob_;
int seed_;
int weight_numel_;
bool initialized_;
bool is_bidirec_;
std::vector<cudnnTensorDescriptor_t> x_desc_;
std::vector<cudnnTensorDescriptor_t> y_desc_;
cudnnRNNDataDescriptor_t x_seq_desc_;
cudnnRNNDataDescriptor_t y_seq_desc_;
// A tensor descriptor describing the initial hidden state of the RNN.
cudnnTensorDescriptor_t hx_desc_;
// A tensor descriptor describing the initial cell state for LSTM networks.
cudnnTensorDescriptor_t cx_desc_;
// A tensor descriptor describing the final hidden state of the RNN.
cudnnTensorDescriptor_t hy_desc_;
// A tensor descriptor describing the final cell state for LSTM networks.
cudnnTensorDescriptor_t cy_desc_;
cudnnDropoutDescriptor_t dropout_desc_;
cudnnFilterDescriptor_t w_desc_;
cudnnRNNDescriptor_t rnn_desc_;
ScopedTensorDescriptor x_d;
ScopedTensorDescriptor y_d;
ScopedRNNTensorDescriptor x_seq_d;
ScopedRNNTensorDescriptor y_seq_d;
ScopedTensorDescriptor hx_d;
ScopedTensorDescriptor cx_d;
ScopedTensorDescriptor hy_d;
ScopedTensorDescriptor cy_d;
ScopedDropoutDescriptor dropout_d;
ScopedFilterDescriptor w_d;
ScopedRNNDescriptor rnn_d;
};
class ScopedConvolutionDescriptor {
public:
ScopedConvolutionDescriptor() {
......
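
Every new Scoped*Descriptor above follows the same create-in-constructor, destroy-in-destructor shape, with descriptor() configuring and handing out the raw handle. A CUDA-free sketch of that RAII pattern (FakeHandle stands in for the cudnn descriptor types):

struct FakeHandle { bool alive = false; };

class ScopedFakeDescriptor {
 public:
  ScopedFakeDescriptor() { h_.alive = true; }    // cudnnCreate*Descriptor
  ~ScopedFakeDescriptor() { h_.alive = false; }  // cudnnDestroy*Descriptor
  ScopedFakeDescriptor(const ScopedFakeDescriptor&) = delete;
  ScopedFakeDescriptor& operator=(const ScopedFakeDescriptor&) = delete;
  FakeHandle* descriptor() { return &h_; }       // configure and hand out
 private:
  FakeHandle h_;
};

int main() {
  ScopedFakeDescriptor d;
  return d.descriptor()->alive ? 0 : 1;
}
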
......@@ -38,14 +38,15 @@ extern void *cublas_dso_handle;
*/
#define DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
struct DynLoad__##__name { \
using FUNC_TYPE = decltype(&::__name); \
template <typename... Args> \
inline cublasStatus_t operator()(Args... args) { \
inline auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using cublas_func = \
decltype(::__name(std::declval<Args>()...)) (*)(Args...); \
std::call_once(cublas_dso_flag, []() { \
cublas_dso_handle = paddle::platform::dynload::GetCublasDsoHandle(); \
}); \
static void *p_##__name = dlsym(cublas_dso_handle, #__name); \
return reinterpret_cast<FUNC_TYPE>(p_##__name)(args...); \
return reinterpret_cast<cublas_func>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
......
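
The macro change derives the return type from the call expression itself instead of a single fixed FUNC_TYPE, which tolerates differing per-function signatures. A minimal sketch of the generated wrapper, demoed on libm's cos rather than cuBLAS (the library path and function are stand-ins):

#include <cmath>
#include <cstdio>
#include <dlfcn.h>  // POSIX; link with -ldl

#define DECLARE_DYNAMIC_LOAD_WRAP(__name)                              \
  struct DynLoad__##__name {                                           \
    template <typename... Args>                                        \
    auto operator()(Args... args) -> decltype(::__name(args...)) {     \
      using func_t = decltype(::__name(args...)) (*)(Args...);         \
      static void* handle = dlopen("libm.so.6", RTLD_LAZY);            \
      static void* p_func = dlsym(handle, #__name);                    \
      return reinterpret_cast<func_t>(p_func)(args...);                \
    }                                                                  \
  }

DECLARE_DYNAMIC_LOAD_WRAP(cos);

int main() { std::printf("%f\n", DynLoad__cos{}(0.0)); }
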
......@@ -101,6 +101,9 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnDropoutGetStatesSize); \
__macro(cudnnSetDropoutDescriptor); \
__macro(cudnnRestoreDropoutDescriptor); \
__macro(cudnnCreateRNNDataDescriptor); \
__macro(cudnnDestroyRNNDataDescriptor); \
__macro(cudnnSetRNNDataDescriptor); \
__macro(cudnnCreateRNNDescriptor); \
__macro(cudnnGetRNNParamsSize); \
__macro(cudnnGetRNNWorkspaceSize); \
......@@ -109,6 +112,11 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnRNNBackwardData); \
__macro(cudnnRNNBackwardWeights); \
__macro(cudnnRNNForwardInference); \
__macro(cudnnRNNForwardTrainingEx); \
__macro(cudnnSetRNNPaddingMode); \
__macro(cudnnRNNBackwardDataEx); \
__macro(cudnnRNNBackwardWeightsEx); \
__macro(cudnnRNNForwardInferenceEx); \
__macro(cudnnDestroyDropoutDescriptor); \
__macro(cudnnDestroyRNNDescriptor); \
__macro(cudnnSetTensorNdDescriptorEx);
......
......@@ -161,6 +161,12 @@ inline mkldnn::memory::data_type MKLDNNGetDataType<uint8_t>() {
return mkldnn::memory::data_type::u8;
}
template <>
inline mkldnn::memory::data_type
MKLDNNGetDataType<paddle::platform::bfloat16>() {
return mkldnn::memory::data_type::bf16;
}
inline void Reorder(mkldnn::memory src, mkldnn::memory dst,
const mkldnn::engine& engine) {
auto reorder_prim = mkldnn::reorder(src, dst);
......
......@@ -59,6 +59,7 @@ void BindGenerator(py::module* m_ptr) {
.def_property("_is_init_py", &framework::Generator::GetIsInitPy,
&framework::Generator::SetIsInitPy);
m.def("default_cpu_generator", &framework::DefaultCPUGenerator);
} // end Generator
} // end namespace pybind
m.def("default_cuda_generator", &framework::GetDefaultCUDAGenerator);
}
} // namespace pybind
} // namespace paddle
......@@ -60,6 +60,9 @@ void BindAnalysisConfig(py::module *m);
void BindAnalysisPredictor(py::module *m);
void BindZeroCopyTensor(py::module *m);
void BindPaddlePassBuilder(py::module *m);
void BindPaddleInferPredictor(py::module *m);
void BindPaddleInferTensor(py::module *m);
void BindPredictorPool(py::module *m);
#ifdef PADDLE_WITH_MKLDNN
void BindMkldnnQuantizerConfig(py::module *m);
......@@ -139,6 +142,15 @@ void ZeroCopyTensorCreate(ZeroCopyTensor &tensor, // NOLINT
tensor.copy_from_cpu(static_cast<const T *>(data.data()));
}
template <typename T>
void PaddleInferTensorCreate(paddle_infer::Tensor &tensor, // NOLINT
py::array_t<T> data) {
std::vector<int> shape;
std::copy_n(data.shape(), data.ndim(), std::back_inserter(shape));
tensor.Reshape(std::move(shape));
tensor.CopyFromCpu(static_cast<const T *>(data.data()));
}
size_t PaddleGetDTypeSize(PaddleDType dt) {
size_t size{0};
switch (dt) {
......@@ -183,6 +195,30 @@ py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) { // NOLINT
return array;
}
py::array PaddleInferTensorToNumpy(paddle_infer::Tensor &tensor) { // NOLINT
py::dtype dt = PaddleDTypeToNumpyDType(tensor.type());
auto tensor_shape = tensor.shape();
py::array::ShapeContainer shape(tensor_shape.begin(), tensor_shape.end());
py::array array(dt, std::move(shape));
switch (tensor.type()) {
case PaddleDType::INT32:
tensor.CopyToCpu(static_cast<int32_t *>(array.mutable_data()));
break;
case PaddleDType::INT64:
tensor.CopyToCpu(static_cast<int64_t *>(array.mutable_data()));
break;
case PaddleDType::FLOAT32:
tensor.CopyToCpu<float>(static_cast<float *>(array.mutable_data()));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type. Now only supports INT32, INT64 and "
"FLOAT32."));
}
return array;
}
py::bytes SerializePDTensorToBytes(PaddleTensor &tensor) { // NOLINT
std::stringstream ss;
paddle::inference::SerializePDTensorToStream(&ss, tensor);
......@@ -200,8 +236,11 @@ void BindInferenceApi(py::module *m) {
BindNativePredictor(m);
BindAnalysisConfig(m);
BindAnalysisPredictor(m);
BindPaddleInferPredictor(m);
BindZeroCopyTensor(m);
BindPaddleInferTensor(m);
BindPaddlePassBuilder(m);
BindPredictorPool(m);
#ifdef PADDLE_WITH_MKLDNN
BindMkldnnQuantizerConfig(m);
#endif
......@@ -209,8 +248,17 @@ void BindInferenceApi(py::module *m) {
&paddle::CreatePaddlePredictor<AnalysisConfig>, py::arg("config"));
m->def("create_paddle_predictor",
&paddle::CreatePaddlePredictor<NativeConfig>, py::arg("config"));
m->def("create_predictor", [](const paddle_infer::Config &config)
-> std::unique_ptr<paddle_infer::Predictor> {
auto pred =
std::unique_ptr<paddle_infer::Predictor>(
new paddle_infer::Predictor(config));
return std::move(pred);
});
m->def("paddle_dtype_size", &paddle::PaddleDtypeSize);
m->def("paddle_tensor_to_bytes", &SerializePDTensorToBytes);
m->def("get_version", &paddle_infer::GetVersion);
m->def("get_num_bytes_of_data_type", &paddle_infer::GetNumBytesOfDataType);
}
namespace {
......@@ -525,6 +573,19 @@ void BindAnalysisPredictor(py::module *m) {
py::arg("dir"));
}
void BindPaddleInferPredictor(py::module *m) {
py::class_<paddle_infer::Predictor>(*m, "PaddleInferPredictor")
.def(py::init<const paddle_infer::Config &>())
.def("get_input_names", &paddle_infer::Predictor::GetInputNames)
.def("get_output_names", &paddle_infer::Predictor::GetOutputNames)
.def("get_input_handle", &paddle_infer::Predictor::GetInputHandle)
.def("get_output_handle", &paddle_infer::Predictor::GetOutputHandle)
.def("run", &paddle_infer::Predictor::Run)
.def("clone", &paddle_infer::Predictor::Clone)
.def("clear_intermediate_tensor",
&paddle_infer::Predictor::ClearIntermediateTensor);
}
void BindZeroCopyTensor(py::module *m) {
py::class_<ZeroCopyTensor>(*m, "ZeroCopyTensor")
.def("reshape", &ZeroCopyTensor::Reshape)
......@@ -538,6 +599,26 @@ void BindZeroCopyTensor(py::module *m) {
.def("type", &ZeroCopyTensor::type);
}
void BindPaddleInferTensor(py::module *m) {
py::class_<paddle_infer::Tensor>(*m, "PaddleInferTensor")
.def("reshape", &paddle_infer::Tensor::Reshape)
.def("copy_from_cpu", &PaddleInferTensorCreate<int32_t>)
.def("copy_from_cpu", &PaddleInferTensorCreate<int64_t>)
.def("copy_from_cpu", &PaddleInferTensorCreate<float>)
.def("copy_to_cpu", &PaddleInferTensorToNumpy)
.def("shape", &paddle_infer::Tensor::shape)
.def("set_lod", &paddle_infer::Tensor::SetLoD)
.def("lod", &paddle_infer::Tensor::lod)
.def("type", &paddle_infer::Tensor::type);
}
void BindPredictorPool(py::module *m) {
py::class_<paddle_infer::services::PredictorPool>(*m, "PredictorPool")
.def(py::init<const paddle_infer::Config &, size_t>())
.def("retrive", &paddle_infer::services::PredictorPool::Retrive,
py::return_value_policy::reference);
}
void BindPaddlePassBuilder(py::module *m) {
py::class_<PaddlePassBuilder>(*m, "PaddlePassBuilder")
.def(py::init<const std::vector<std::string> &>())
......
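These bindings expose the new paddle_infer Predictor, Tensor and PredictorPool classes to Python. A minimal end-to-end sketch follows; it assumes the bindings are surfaced through a paddle.inference-style wrapper and uses placeholder model file names, so treat the module path and file names as assumptions rather than the exact surface of this commit.
.. code-block:: python
import numpy as np
from paddle.inference import Config, create_predictor  # assumed wrapper over these bindings
# Configure and build the predictor (maps to create_predictor / PaddleInferPredictor above).
config = Config("inference.pdmodel", "inference.pdiparams")  # placeholder model files
predictor = create_predictor(config)
# Feed an input through a PaddleInferTensor handle: reshape, then copy_from_cpu.
input_name = predictor.get_input_names()[0]
input_handle = predictor.get_input_handle(input_name)
fake_input = np.random.randn(1, 3, 224, 224).astype("float32")
input_handle.reshape([1, 3, 224, 224])
input_handle.copy_from_cpu(fake_input)
# Run and fetch the result back into a numpy array via copy_to_cpu.
predictor.run()
output_name = predictor.get_output_names()[0]
output_handle = predictor.get_output_handle(output_name)
output = output_handle.copy_to_cpu()
print(output.shape)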
......@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "pybind11/numpy.h"
......@@ -104,6 +105,7 @@ struct ValidDTypeToPyArrayChecker {
}
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::bfloat16);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(float);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(double);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(bool);
......@@ -119,6 +121,9 @@ inline std::string TensorDTypeToPyDTypeStr(
if (type == proto_type) { \
if (std::is_same<T, platform::float16>::value) { \
return "e"; \
} else if (std::is_same<T, platform::bfloat16>::value) { \
/* NumPy character code of uint16 due to no support for bfloat16 */ \
return "H"; \
} else { \
constexpr auto kIsValidDType = ValidDTypeToPyArrayChecker<T>::kValue; \
PADDLE_ENFORCE_EQ( \
......@@ -262,10 +267,10 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj,
SetTensorFromPyArrayT<paddle::platform::float16, P>(self, array, place,
zero_copy);
} else if (py::isinstance<py::array_t<uint16_t>>(array)) {
// TODO(cql): temporary keeping uint16, which is used for casting float16
// before. It should be depracated later.
SetTensorFromPyArrayT<paddle::platform::float16, P>(self, array, place,
zero_copy);
// since there is still no support for bfloat16 in NumPy,
// uint16 is used for casting bfloat16
SetTensorFromPyArrayT<paddle::platform::bfloat16, P>(self, array, place,
zero_copy);
} else if (py::isinstance<py::array_t<bool>>(array)) {
SetTensorFromPyArrayT<bool, P>(self, array, place, zero_copy);
} else {
......@@ -479,6 +484,8 @@ inline framework::Tensor *_sliceTensor(const framework::Tensor &self,
switch (src_type) {
case framework::proto::VarType::FP16:
return _sliceAndConcat<paddle::platform::float16>(self, obj, dim);
case framework::proto::VarType::BF16:
return _sliceAndConcat<paddle::platform::bfloat16>(self, obj, dim);
case framework::proto::VarType::FP32:
return _sliceAndConcat<float>(self, obj, dim);
case framework::proto::VarType::FP64:
......
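The bfloat16 changes above rely on the convention spelled out in the comments: NumPy has no bfloat16 dtype, so uint16 arrays (character code "H") carry the raw bfloat16 bit patterns. The snippet below is a standalone numpy illustration of that carrier representation, using simple truncation and assuming a little-endian host; Paddle's own conversion may round rather than truncate.
.. code-block:: python
import numpy as np
def float32_to_bfloat16_bits(x):
    # bfloat16 keeps the sign, the 8 exponent bits and the top 7 mantissa bits
    # of float32, i.e. the high half of the 32-bit word (truncation, no rounding).
    x = np.ascontiguousarray(x, dtype=np.float32)
    # On a little-endian host the high half-word is the second uint16 of each pair.
    return x.view(np.uint16).reshape(-1, 2)[:, 1].reshape(x.shape)
def bfloat16_bits_to_float32(bits):
    # Re-expand the uint16 bit patterns into float32 by zero-filling the low half.
    bits = np.ascontiguousarray(bits, dtype=np.uint16)
    words = np.zeros(bits.shape + (2,), dtype=np.uint16)
    words[..., 1] = bits
    return words.view(np.float32).reshape(bits.shape)
a = np.array([1.0, 3.1415926, -2.5], dtype=np.float32)
bits = float32_to_bfloat16_bits(a)                  # dtype uint16 ("H"), what the bindings expect
print(bits.dtype, bfloat16_bits_to_float32(bits))   # uint16 [ 1.  3.140625 -2.5 ]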
......@@ -20,13 +20,12 @@ rem Paddle CI Task On Windows Platform
rem =================================================
rem -------clean up environment-----------
wmic process where name="op_function_generator.exe" call terminate 2>NUL
set work_dir=%cd%
if exist build rmdir build /s/q
mkdir build
cd /d build
tree .
dir paddle\fluid\pybind\Release
taskkill /f /im op_function_generator.exe 2>NUL
rem ------initialize the virtual environment------
if not defined PYTHON_ROOT set PYTHON_ROOT=C:\Python37
......@@ -59,7 +58,7 @@ if not defined WITH_AVX set WITH_AVX=ON
if not defined WITH_TESTING set WITH_TESTING=ON
if not defined WITH_PYTHON set WITH_PYTHON=ON
if not defined ON_INFER set ON_INFER=ON
if not defined WITH_INFERENCE_API_TEST set WITH_INFERENCE_API_TEST=OFF
if not defined WITH_INFERENCE_API_TEST set WITH_INFERENCE_API_TEST=ON
if not defined WITH_TPCACHE set WITH_TPCACHE=ON
rem ------set cache third_party------
......@@ -243,7 +242,7 @@ dir %THIRD_PARTY_PATH:/=\%\install\mkldnn\bin
dir %THIRD_PARTY_PATH:/=\%\install\warpctc\bin
set PATH=%THIRD_PARTY_PATH:/=\%\install\openblas\lib;%THIRD_PARTY_PATH:/=\%\install\openblas\bin;%THIRD_PARTY_PATH:/=\%\install\zlib\bin;%THIRD_PARTY_PATH:/=\%\install\mklml\lib;%THIRD_PARTY_PATH:/=\%\install\mkldnn\bin;%THIRD_PARTY_PATH:/=\%\install\warpctc\bin;%PATH%
ctest.exe --output-on-failure -C Release -j 8
ctest.exe --output-on-failure -C Release -j 8 --repeat until-pass:4
goto:eof
:unit_test_error
......@@ -402,7 +401,7 @@ taskkill /f /im git-remote-https.exe 2>NUL
taskkill /f /im vctip.exe 2>NUL
taskkill /f /im cvtres.exe 2>NUL
taskkill /f /im rc.exe 2>NUL
taskkill /f /im op_function_generator.exe 2>NUL
wmic process where name="op_function_generator.exe" call terminate 2>NUL
taskkill /f /im python.exe 2>NUL
call paddle_winci\Scripts\deactivate.bat 2>NUL
taskkill /f /im python.exe 2>NUL
......
......@@ -528,8 +528,50 @@ EOF
elif [ "$1" == "cp37-cp37m" ]; then
pip3.7 install --user ${INSTALL_PREFIX:-/paddle/build}/opt/paddle/share/wheels/*.whl
fi
tmpfile_rand=`date +%s%N`
tmpfile=$tmp_dir/$tmpfile_rand
set +e
ut_startTime_s=`date +%s`
ctest --output-on-failure -j $2;mactest_error=$?
ctest --output-on-failure -j $2 | tee $tmpfile
failed_test_lists=''
collect_failed_tests
set +x
mactest_error=0
retry_unittests_record=''
retry_time=3
exec_times=0
exec_time_array=('first' 'second' 'third')
if [ -n "$failed_test_lists" ];then
mactest_error=1
while ( [ $exec_times -lt $retry_time ] && [ -n "${failed_test_lists}" ] )
do
retry_unittests_record="$retry_unittests_record$failed_test_lists"
failed_test_lists_ult=`echo "${failed_test_lists}"`
read retry_unittests <<< $(echo "$failed_test_lists" | grep -oEi "\-.+\(" | sed 's/(//' | sed 's/- //' )
echo "========================================="
echo "This is the ${exec_time_array[$exec_times]} time to re-run"
echo "========================================="
echo "The following unittest will be re-run:"
echo "${retry_unittests}"
echo "========================================="
retry_unittests_regular=''
for line in ${retry_unittests[@]} ;
do
if [[ "$retry_unittests_regular" == "" ]];then
retry_unittests_regular="^$line$"
else
retry_unittests_regular="$retry_unittests_regular|^$line$"
fi
done
rm -f $tmp_dir/*
failed_test_lists=''
ctest -R "($retry_unittests_regular)" --output-on-failure -j $2 | tee $tmpfile
collect_failed_tests
exec_times=$[$exec_times+1]
done
fi
#mactest_error=$?
ut_endTime_s=`date +%s`
echo "Mac testCase Time: $[ $ut_endTime_s - $ut_startTime_s ]s"
paddle version
......@@ -537,7 +579,21 @@ EOF
export http_proxy=$my_proxy
export https_proxy=$my_proxy
if [ "$mactest_error" != 0 ];then
exit 8;
if [[ "$failed_test_lists" == "" ]]; then
echo "========================================"
echo "There are failed tests, which have been successful after re-run:"
echo "========================================"
echo "The following tests have been re-ran:"
echo "${retry_unittests_record}"
else
failed_test_lists_ult=`echo "${failed_test_lists}"`
echo "========================================"
echo "Summary Failed Tests... "
echo "========================================"
echo "The following tests FAILED: "
echo "${failed_test_lists_ult}"
exit 8;
fi
fi
fi
}
......@@ -561,6 +617,7 @@ function fetch_upstream_develop_if_not_exist() {
function generate_upstream_develop_api_spec() {
fetch_upstream_develop_if_not_exist
cur_branch=`git branch | grep \* | cut -d ' ' -f2`
git checkout .
git checkout -b develop_base_pr upstream/$BRANCH
cmake_gen $1
build $2
......@@ -1421,6 +1478,7 @@ function main() {
init
if [ "$CMD" != "assert_file_approvals" ];then
python ${PADDLE_ROOT}/tools/summary_env.py
bash ${PADDLE_ROOT}/tools/get_cpu_info.sh
fi
case $CMD in
build_only)
......
......@@ -217,6 +217,8 @@ from .tensor.search import index_select #DEFINE_ALIAS
from .tensor.search import nonzero #DEFINE_ALIAS
from .tensor.search import sort #DEFINE_ALIAS
from .framework.random import manual_seed #DEFINE_ALIAS
from .framework.random import get_cuda_rng_state #DEFINE_ALIAS
from .framework.random import set_cuda_rng_state #DEFINE_ALIAS
from .framework import Variable #DEFINE_ALIAS
from .framework import ParamAttr #DEFINE_ALIAS
from .framework import create_global_var #DEFINE_ALIAS
......@@ -230,6 +232,7 @@ from .framework import grad #DEFINE_ALIAS
from .framework import no_grad #DEFINE_ALIAS
from .framework import save #DEFINE_ALIAS
from .framework import load #DEFINE_ALIAS
from .framework import SaveLoadConfig #DEFINE_ALIAS
from .framework import DataParallel #DEFINE_ALIAS
from .framework import NoamDecay #DEFINE_ALIAS
......@@ -267,5 +270,6 @@ from . import static
# high-level api
from .hapi import Model
from .hapi import callbacks
from .hapi import summary
import paddle.text
import paddle.vision
......@@ -74,7 +74,8 @@ def load_data(filename, feature_num=14, ratio=0.8):
data = data.reshape(data.shape[0] // feature_num, feature_num)
maximums, minimums, avgs = data.max(axis=0), data.min(axis=0), data.sum(
axis=0) / data.shape[0]
feature_range(maximums[:-1], minimums[:-1])
# to print the distribution of the input data, uncomment the feature_range call below
#feature_range(maximums[:-1], minimums[:-1])
for i in six.moves.range(feature_num - 1):
data[:, i] = (data[:, i] - avgs[i]) / (maximums[i] - minimums[i])
offset = int(data.shape[0] * ratio)
......
......@@ -50,3 +50,10 @@ distributed_optimizer = fleet.distributed_optimizer
save_inference_model = fleet.save_inference_model
save_persistables = fleet.save_persistables
minimize = fleet.minimize
distributed_model = fleet.distributed_model
step = fleet.step
clear_grad = fleet.clear_grad
set_lr = fleet.set_lr
get_lr = fleet.get_lr
state_dict = fleet.state_dict
set_state_dict = fleet.set_state_dict
......@@ -13,7 +13,11 @@
# limitations under the License.
from __future__ import print_function
import copy
import warnings
import paddle
from paddle.fluid.framework import dygraph_only
from paddle.fluid import compiler
from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
from .strategy_compiler import StrategyCompiler
from .distributed_strategy import DistributedStrategy
......@@ -21,6 +25,7 @@ from .meta_optimizer_factory import MetaOptimizerFactory
from .runtime_factory import RuntimeFactory
from .util_factory import UtilFactory
from paddle.fluid.wrapped_decorator import wrap_decorator
from paddle.fluid.dygraph import parallel_helper
def _inited_runtime_handler_(func):
......@@ -35,7 +40,24 @@ def _inited_runtime_handler_(func):
return __impl__
def _is_non_distributed_check_(func):
def __impl__(*args, **kwargs):
cls = args[0]
if cls._role_maker is not None and cls._role_maker._is_non_distributed(
) is True:
warnings.warn(
"%s() function doesn't work when use non_distributed fleet." %
(func.__name__))
return
return func(*args, **kwargs)
return __impl__
inited_runtime_handler = wrap_decorator(_inited_runtime_handler_)
is_non_distributed_check = wrap_decorator(_is_non_distributed_check_)
class Fleet(object):
......@@ -159,6 +181,12 @@ class Fleet(object):
"`role_maker` should be subclass of `RoleMakerBase`, but got {}".
format(type(role_maker)))
self.strategy_compiler = StrategyCompiler()
if paddle.fluid.framework.in_dygraph_mode():
if parallel_helper._is_parallel_ctx_initialized():
warnings.warn(
"The dygraph parallel environment has been initialized.")
else:
paddle.distributed.init_parallel_env()
return None
def is_first_worker(self):
......@@ -367,6 +395,7 @@ class Fleet(object):
"""
self._role_maker.barrier_worker()
@is_non_distributed_check
@inited_runtime_handler
def init_worker(self):
"""
......@@ -391,6 +420,7 @@ class Fleet(object):
"""
self._runtime_handle._init_worker()
@is_non_distributed_check
@inited_runtime_handler
def init_server(self, *args, **kwargs):
"""
......@@ -416,6 +446,7 @@ class Fleet(object):
"""
self._runtime_handle._init_server(*args, **kwargs)
@is_non_distributed_check
@inited_runtime_handler
def run_server(self):
"""
......@@ -440,6 +471,7 @@ class Fleet(object):
"""
self._runtime_handle._run_server()
@is_non_distributed_check
@inited_runtime_handler
def stop_worker(self):
"""
......@@ -564,12 +596,344 @@ class Fleet(object):
"""
self.user_defined_optimizer = optimizer
if paddle.fluid.framework.in_dygraph_mode():
return self
if strategy is None:
strategy = DistributedStrategy()
self.user_defined_strategy = strategy
self.valid_strategy = None
return self
@dygraph_only
def distributed_model(self, model):
"""
Return a distributed data parallel model (a ``Layer``) wrapping the given model.
Only works in dygraph mode.
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
from paddle.distributed import fleet
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)
def forward(self, x):
return self._linear2(self._linear1(x))
def train():
# 1. enable dynamic mode
paddle.disable_static()
# 2. initialize fleet environment
fleet.init(is_collective=True)
# 3. create layer & optimizer
layer = LinearNet()
loss_fn = nn.MSELoss()
adam = paddle.optimizer.Adam(
learning_rate=0.001, parameters=layer.parameters())
# 4. get data_parallel model using fleet
adam = fleet.distributed_optimizer(adam)
dp_layer = fleet.distributed_model(layer)
# 5. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)
print("loss:", loss.numpy())
loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()
adam.step()
adam.clear_grad()
if __name__ == '__main__':
paddle.distributed.spawn(train)
"""
assert model is not None
self.model = paddle.DataParallel(model)
return self.model
@dygraph_only
def state_dict(self):
"""
Get state dict information from optimizer.
Only works in dygraph mode.
Returns:
state_dict(dict) : dict containing all the Tensors used by the optimizer
Examples:
.. code-block:: python
import numpy as np
import paddle
from paddle.distributed import fleet
paddle.disable_static()
fleet.init(is_collective=True)
value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.fluid.dygraph.to_variable(value)
layer = paddle.nn.Linear(13, 5)
adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())
adam = fleet.distributed_optimizer(adam)
dp_layer = fleet.distributed_model(layer)
state_dict = adam.state_dict()
"""
# imitate target optimizer retrieval
return self.user_defined_optimizer.state_dict()
@dygraph_only
def set_state_dict(self, state_dict):
"""
Load optimizer state dict.
Only works in dygraph mode.
Args:
state_dict(dict) : dict containing all the Tensors needed by the optimizer
Returns: None
Examples:
.. code-block:: python
import numpy as np
import paddle
from paddle.distributed import fleet
paddle.disable_static()
fleet.init(is_collective=True)
value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.fluid.dygraph.to_variable(value)
layer = paddle.nn.Linear(13, 5)
adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())
adam = fleet.distributed_optimizer(adam)
dp_layer = fleet.distributed_model(layer)
state_dict = adam.state_dict()
paddle.framework.save(state_dict, "paddle_dy")
para_state_dict, opti_state_dict = paddle.framework.load( "paddle_dy")
adam.set_state_dict(opti_state_dict)
"""
# imitate target optimizer retrieval
return self.user_defined_optimizer.set_state_dict(state_dict)
@dygraph_only
def set_lr(self, value):
"""
Set the value of the learning rate manually in the optimizer.
Only works in dygraph mode.
Args:
value (float|Tensor): the value of learning rate
Returns: None
Examples:
.. code-block:: python
import numpy as np
import paddle
from paddle.distributed import fleet
paddle.disable_static()
fleet.init(is_collective=True)
value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.fluid.dygraph.to_variable(value)
layer = paddle.nn.Linear(13, 5)
adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())
adam = fleet.distributed_optimizer(adam)
dp_layer = fleet.distributed_model(layer)
lr_list = [0.2, 0.3, 0.4, 0.5, 0.6]
for i in range(5):
adam.set_lr(lr_list[i])
lr = adam.get_lr()
print("current lr is {}".format(lr))
# Print:
# current lr is 0.2
# current lr is 0.3
# current lr is 0.4
# current lr is 0.5
# current lr is 0.6
"""
# imitate target optimizer retrieval
return self.user_defined_optimizer.set_lr(value)
@dygraph_only
def get_lr(self):
"""
Get current step learning rate.
Only works in dygraph mode.
Returns:
float: The learning rate of the current step.
Examples:
.. code-block:: python
import numpy as np
import paddle
from paddle.distributed import fleet
paddle.disable_static()
fleet.init(is_collective=True)
value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.fluid.dygraph.to_variable(value)
layer = paddle.nn.Linear(13, 5)
adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())
adam = fleet.distributed_optimizer(adam)
dp_layer = fleet.distributed_model(layer)
lr = adam.get_lr()
print(lr) # 0.01
"""
# imitate target optimizer retrieval
return self.user_defined_optimizer.get_lr()
@dygraph_only
def step(self):
"""
Execute one step of the optimizer.
Only works in dygraph mode.
Returns: None
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
from paddle.distributed import fleet
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)
def forward(self, x):
return self._linear2(self._linear1(x))
def train():
# 1. enable dynamic mode
paddle.disable_static()
# 2. initialize fleet environment
fleet.init(is_collective=True)
# 3. create layer & optimizer
layer = LinearNet()
loss_fn = nn.MSELoss()
adam = paddle.optimizer.Adam(
learning_rate=0.001, parameters=layer.parameters())
# 4. get data_parallel model using fleet
adam = fleet.distributed_optimizer(adam)
dp_layer = fleet.distributed_model(layer)
# 5. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)
print("loss:", loss.numpy())
loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()
adam.step()
adam.clear_grad()
if __name__ == '__main__':
paddle.distributed.spawn(train)
"""
# imitate target optimizer retrieval
return self.user_defined_optimizer.step()
@dygraph_only
def clear_grad(self):
"""
Clear the gradients of all the parameters optimized by this optimizer.
Only works in dygraph mode.
Returns: None
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
from paddle.distributed import fleet
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)
def forward(self, x):
return self._linear2(self._linear1(x))
def train():
# 1. enable dynamic mode
paddle.disable_static()
# 2. initialize fleet environment
fleet.init(is_collective=True)
# 3. create layer & optimizer
layer = LinearNet()
loss_fn = nn.MSELoss()
adam = paddle.optimizer.Adam(
learning_rate=0.001, parameters=layer.parameters())
# 4. get data_parallel model using fleet
adam = fleet.distributed_optimizer(adam)
dp_layer = fleet.distributed_model(layer)
# 5. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)
print("loss:", loss.numpy())
loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()
adam.step()
adam.clear_grad()
if __name__ == '__main__':
paddle.distributed.spawn(train)
"""
# imitate target optimizer retrieval
return self.user_defined_optimizer.clear_grad()
def minimize(self,
loss,
startup_program=None,
......@@ -593,8 +957,8 @@ class Fleet(object):
tuple: tuple (optimize_ops, params_grads), A list of operators appended
by minimize and a list of (param, grad) variable pairs, param is
``Parameter``, grad is the gradient value corresponding to the parameter.
The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
indicate program pruning. If so, the program will be pruned by ``feed`` and
The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
indicate program pruning. If so, the program will be pruned by ``feed`` and
``fetch_list`` before run, see details in ``Executor``.
Examples:
......@@ -619,6 +983,11 @@ class Fleet(object):
# for more examples, please reference https://github.com/PaddlePaddle/FleetX
"""
if paddle.fluid.framework.in_dygraph_mode():
# imitate target optimizer retrieval
target_opt = self.user_defined_optimizer
return target_opt.minimize(loss)
context = {}
# cache original feed forward program
self.origin_main_program = loss.block.program
......@@ -640,6 +1009,18 @@ class Fleet(object):
MetaOptimizerFactory()._get_valid_meta_optimizers(
self.user_defined_optimizer)
context["user_defined_strategy"] = copy.copy(self.user_defined_strategy)
# trigger the auto-parallel in very strict condition
# strategy = DistributedStrategy()
# strategy.auto = True
# optimizer = paddle.optimizer.SGD(learning_rate=0.1)
# optimizer = fleet.distributed_optimizer(optimizer, strategy)
if self.user_defined_strategy._is_strict_auto():
# turn on all the strategy for each optimizer
for opt in distributed_optimizer_list:
opt._enable_strategy(self.user_defined_strategy)
valid_optimizer_list = []
valid_graph_optimizer_list = []
can_not_apply_optimizer_list = []
......@@ -672,6 +1053,20 @@ class Fleet(object):
optimize_ops = []
params_grads = []
if self._role_maker._is_non_distributed() and not self._is_collective:
if self._runtime_handle is None:
self._runtime_handle = RuntimeFactory()._create_runtime(context)
compiled_program = compiler.CompiledProgram(
self.origin_main_program).with_data_parallel(
loss_name=loss.name, share_vars_from=None)
loss.block.program._graph = compiled_program
return self.user_defined_optimizer.minimize(
loss,
startup_program=startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set)
if meta_optimizer:
optimize_ops, params_grads = meta_optimizer.minimize(
loss,
......
......@@ -232,6 +232,8 @@ class PaddleCloudRoleMaker(RoleMakerBase):
self._node_type_comm = None
self._all_comm = None
self._non_distributed = False
if not self._is_collective:
self._hdfs_name = kwargs.get("hdfs_name", "")
self._hdfs_ugi = kwargs.get("hdfs_ugi", "")
......@@ -373,6 +375,15 @@ class PaddleCloudRoleMaker(RoleMakerBase):
self.generate_role()
return self._server_endpoints
def _is_non_distributed(self):
"""
Return True if the environment variables required by fleetrun are not found,
i.e. the fleet code was launched directly with python instead of fleetrun.
"""
if not self._role_is_generated:
self.generate_role()
return self._non_distributed
def _heter_worker_num(self):
"""
get heter worker nums
......@@ -409,13 +420,22 @@ class PaddleCloudRoleMaker(RoleMakerBase):
try:
# Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set
# format: string(ip:port,ip:port), eg. 127.0.0.1:6001,127.0.0.1:6002
self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST",
"").split(",")
assert self._server_endpoints != ""
self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST")
self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS",
"").split(",")
assert self._server_endpoints != ""
if self._server_endpoints is None:
# back to non_distributed execution.
self._server_endpoints = ""
self._trainers_num = 1
self._role = Role.WORKER
self._current_id = 0
self._node_num = 1
self._heter_trainers_num = 0
self._heter_trainer_endpoints = None
self._non_distributed = True
return
self._server_endpoints = self._server_endpoints.split(",")
trainers_num = int(os.environ["PADDLE_TRAINERS_NUM"])
training_role = os.environ["TRAINING_ROLE"]
......@@ -488,7 +508,11 @@ class PaddleCloudRoleMaker(RoleMakerBase):
assert (self._training_role == "TRAINER")
self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS")
self._cur_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
assert self._worker_endpoints is not None, "can't find PADDLE_TRAINER_ENDPOINTS"
if self._worker_endpoints is None:
# back to non_distributed execution.
self._worker_endpoints = "127.0.0.1:6170"
self._cur_endpoint = self._worker_endpoints
self._non_distributed = True
self._worker_endpoints = self._worker_endpoints.split(",")
self._trainers_num = len(self._worker_endpoints)
self._node_num = len(
......
......@@ -200,11 +200,11 @@ def launch_collective(args):
start_port = os.environ.get('FLAGS_START_PORT')
if cloud_utils.use_paddlecloud() and trainers_num != 1:
cluster, pod = cloud_utils.get_cloud_cluster(args.ips, gpus, start_port)
logger.info("get cluster from cloud:{}".format(cluster))
logger.debug("get cluster from cloud:{}".format(cluster))
else:
# trainers_num = 1 or not use paddlecloud ips="a,b"
cluster, pod = get_cluster_from_args(args, gpus)
logger.info("get cluster from args:{}".format(cluster))
logger.debug("get cluster from args:{}".format(cluster))
procs = start_local_trainers(
cluster,
......@@ -217,7 +217,8 @@ def launch_collective(args):
alive = watch_local_trainers(procs, cluster.trainers_nranks())
if not alive:
logger.info("Local procs complete, POD info:{}".format(pod))
logger.info("Local processes completed.")
logger.debug("POD info:{}".format(pod))
break
time.sleep(3)
......@@ -313,18 +314,26 @@ def launch_ps(args):
cmds = []
log_fns = []
for idx, cur_server in enumerate(pod.servers):
current_env.update({
proc_env = {
"PADDLE_PSERVERS_IP_PORT_LIST": server_endpoints,
"PADDLE_PORT": cur_server.endpoint.split(":")[1],
"TRAINING_ROLE": "PSERVER",
"PADDLE_TRAINERS_NUM": str(worker_num),
"POD_IP": cur_server.endpoint.split(":")[0]
})
}
current_env.update(proc_env)
cmd = [sys.executable, "-u", args.training_script
] + args.training_script_args
cmds.append(cmd)
if idx == 0:
logger.info(
"Local server start {} processes. First process distributed "
"environment info (Only For Debug): {}".format(
len(pod.servers),
pretty_print_envs(proc_env, ("Distributed Envs", "Value"))))
if args.log_dir is not None:
os.system("mkdir -p {}".format(args.log_dir))
fn = open("%s/serverlog.%d" % (args.log_dir, idx), "w")
......@@ -338,21 +347,32 @@ def launch_ps(args):
tp.rank = cur_server.rank
tp.local_rank = idx
tp.log_fn = fn
tp.log_offset = 0 if fn else None
tp.log_offset = fn.tell() if fn else None
tp.cmd = cmd
procs.append(tp)
for idx, cur_worker in enumerate(pod.workers):
current_env.update({
proc_env = {
"PADDLE_PSERVERS_IP_PORT_LIST": server_endpoints,
"PADDLE_TRAINER_ENDPOINTS": worker_endpoints,
"PADDLE_TRAINERS_NUM": str(worker_num),
"TRAINING_ROLE": "TRAINER",
"PADDLE_TRAINER_ID": str(cur_worker.rank)
})
}
current_env.update(proc_env)
cmd = [sys.executable, "-u", args.training_script
] + args.training_script_args
cmds.append(cmd)
if idx == 0:
logger.info(
"Local worker start {} processes. First process distributed "
"environment info (Only For Debug): {}".format(
len(pod.workers),
pretty_print_envs(proc_env, ("Distributed Envs", "Value"))))
if args.log_dir is not None:
os.system("mkdir -p {}".format(args.log_dir))
fn = open("%s/workerlog.%d" % (args.log_dir, idx), "w")
......@@ -366,11 +386,14 @@ def launch_ps(args):
tp.rank = cur_worker.rank
tp.local_rank = idx
tp.log_fn = fn
tp.log_offset = 0 if fn else None
tp.log_offset = fn.tell() if fn else None
tp.cmd = cmd
procs.append(tp)
logger.info(
"Please check servers and workers logs in {}/workerlog.* and {}/serverlog.*".
format(args.log_dir, args.log_dir))
# only wait worker to finish here
for i, proc in enumerate(procs):
if i < len(pod.servers):
......@@ -403,16 +426,16 @@ def launch():
cuda_device_num = fluid.core.get_cuda_device_count()
if len(has_ps_args) > 0 or cuda_device_num == 0:
logger.info(
"Run parameter-sever cpu mode. pserver args:{}, cuda count:{}".
"Run parameter-sever cpu mode. pserver arguments:{}, cuda count:{}".
format(has_ps_args, cuda_device_num))
launch_ps(args)
elif len(has_collective_args) > 0:
logger.info("Run collective gpu mode. gpu args:{}, cuda count:{}".
logger.info("Run collective gpu mode. gpu arguments:{}, cuda count:{}".
format(has_collective_args, cuda_device_num))
launch_collective(args)
else:
logger.warning(
"Not found distinct args. Default use gpu collective mode")
"Not found distinct arguments. Default use gpu collective mode")
launch_collective(args)
......
......@@ -253,7 +253,8 @@ def terminate_local_procs(procs):
for p in procs:
if p.proc.poll() is None:
p.proc.terminate()
p.log_fn.close()
if p.log_fn:
p.log_fn.close()
logger.debug("terminate process id:{}".format(p.proc.pid))
# wait until all processes have terminated
......@@ -338,6 +339,45 @@ def get_ports(num, offset):
return ports
def pretty_print_envs(envs, header=None):
spacing = 2
max_k = 40
max_v = 45
for k, v in envs.items():
max_k = max(max_k, len(k))
h_format = "{{:^{}s}}{}{{:<{}s}}\n".format(max_k, " " * spacing, max_v)
l_format = "{{:<{}s}}{{}}{{:<{}s}}\n".format(max_k, max_v)
length = max_k + max_v + spacing
border = "".join(["="] * length)
line = "".join(["-"] * length)
draws = ""
draws += border + "\n"
if header:
draws += h_format.format(header[0], header[1])
else:
draws += h_format.format("fleetrun Distributed Envs", "Value")
draws += line + "\n"
for k, v in envs.items():
if isinstance(v, str) and len(v) >= max_v:
str_v = "... " + v[-41:]
else:
str_v = v
draws += l_format.format(k, " " * spacing, str(str_v))
draws += border
_str = "\n{}\n".format(draws)
return _str
class TrainerProc(object):
def __init__(self):
self.proc = None
......@@ -373,11 +413,19 @@ def start_local_trainers(cluster,
current_env.update(proc_env)
logger.debug("trainer proc env:{}".format(current_env))
cmd = [sys.executable, "-u", training_script] + training_script_args
logger.info("start trainer proc:{} env:{}".format(cmd, proc_env))
logger.debug("start trainer proc{} env:{}".format(cmd, current_env))
if idx == 0:
logger.info("Local start {} processes. First process distributed "
"environment info (Only For Debug): {}".format(
len(pod.trainers),
pretty_print_envs(proc_env, ("Distributed Envs",
"Value"))))
logger.info(
"More details for debug about commands and environments are written in {}/run.sh".
format(log_dir))
fn = None
if log_dir is not None:
......
......@@ -42,6 +42,17 @@ class AMPOptimizer(MetaOptimizerBase):
dist_strategy.amp = False
dist_strategy.amp_configs = {}
def _enable_strategy(self, dist_strategy):
dist_strategy.amp = True
dist_strategy.amp_configs = {
"init_loss_scaling": 32768.0,
"incr_every_n_steps": 1000,
"decr_every_n_nan_or_inf": 2,
"incr_ratio": 2.0,
"decr_ratio": 8.0,
"use_dynamic_loss_scaling": True
}
def minimize_impl(self,
loss,
startup_program=None,
......
......@@ -69,6 +69,10 @@ class DGCOptimizer(MetaOptimizerBase):
dist_strategy.dgc = False
dist_strategy.dgc_configs = {}
def _enable_strategy(self, dist_strategy):
dist_strategy.dgc = True
dist_strategy.dgc_configs = {"rampup_begin_step": 0, "rampup_step": 1}
def backward(self,
loss,
startup_program=None,
......
......@@ -45,6 +45,10 @@ class GradientMergeOptimizer(MetaOptimizerBase):
dist_strategy.gradient_merge = False
dist_strategy.gradient_merge_configs = {}
def _enable_strategy(self, dist_strategy):
# we currently do not support auto-enable gradient merge
return
def minimize_impl(self,
loss,
startup_program=None,
......
......@@ -148,9 +148,6 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
sync_allreduce = dist_strategy.sync_nccl_allreduce
if sync_allreduce:
paddle.fluid.framework.set_flags({
"FLAGS_sync_nccl_allreduce": True
})
exe_strategy.num_threads = local_build_strategy.nccl_comm_num + 1
if local_build_strategy.use_hierarchical_allreduce:
exe_strategy.num_threads = 2 * local_build_strategy.nccl_comm_num + 1
......@@ -191,7 +188,11 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
def _disable_strategy(self, dist_strategy):
# TODO(guru4elephant): should close all PE related flags here
pass
return
def _enable_strategy(self, dist_strategy):
# by default, graph execution strategy is enabled
return
def minimize(self,
loss,
......
......@@ -75,6 +75,13 @@ class LambOptimizer(MetaOptimizerBase):
dist_strategy.lamb = False
dist_strategy.lamb_configs = {}
def _enable_strategy(self, dist_strategy):
dist_strategy.lamb = True
dist_strategy.lamb_configs = {
"lamb_weight_decay": 0.01,
"exclude_from_weight_decay": []
}
def backward(self,
loss,
startup_program=None,
......
......@@ -59,6 +59,13 @@ class LarsOptimizer(MetaOptimizerBase):
dist_strategy.lars = False
dist_strategy.lars_configs = {}
def _enable_strategy(self, dist_strategy):
dist_strategy.lars = True
dist_strategy.lars_configs = {
"lars_coeff": 0.01,
"lars_weight_decay": 0.0005,
}
def backward(self,
loss,
startup_program=None,
......
......@@ -14,8 +14,8 @@
from __future__ import print_function
import paddle
from paddle.fluid import program_guard, layers, default_main_program
from paddle.fluid.optimizer import Momentum, SGD
from .meta_optimizer_base import MetaOptimizerBase
from .common import OpRole, OP_ROLE_KEY, CollectiveHelper, is_update_op
......@@ -35,13 +35,19 @@ class LocalSGDOptimizer(MetaOptimizerBase):
if self.role_maker.worker_num() <= 1:
return False
return isinstance(self.inner_opt, Momentum) \
or isinstance(self.inner_opt, SGD)
return isinstance(self.inner_opt, paddle.optimizer.momentum.Momentum) \
or isinstance(self.inner_opt, paddle.fluid.optimizer.Momentum) \
or isinstance(self.inner_opt, paddle.optimizer.sgd.SGD) \
or isinstance(self.inner_opt, paddle.fluid.optimizer.SGD)
def _disable_strategy(self, dist_strategy):
dist_strategy.localsgd = False
dist_strategy.localsgd_configs = {}
def _enable_strategy(self, dist_strategy):
dist_strategy.localsgd = True
dist_strategy.localsgd_configs = {"k_steps": 1}
def snapshot_name(self, param_name):
return param_name + self.snapshot_key
......
......@@ -48,6 +48,10 @@ class MetaOptimizerBase(Optimizer):
raise NotImplementedError("you should implement disable strategy in {}".
format(type(self).__name__))
def _enable_strategy(self, dist_strategy):
raise NotImplementedError("you should implement enable strategy in {}".
format(type(self).__name__))
def apply_gradients(self, params_grads):
return self.inner_opt.apply_gradients(params_grads=params_grads)
......
......@@ -39,6 +39,11 @@ class ParameterServerGraphOptimizer(ParameterServerOptimizer):
def _disable_strategy(self, dist_strategy):
dist_strategy.a_sync_configs = {}
def _enable_strategy(self, dist_strategy):
# only open up the async mode for auto-parallel
dist_strategy.a_sync = True
dist_strategy.a_sync_configs = {}
def _is_graph_out(self):
return True
......
......@@ -157,4 +157,9 @@ class ParameterServerOptimizer(MetaOptimizerBase):
return None, None
def _disable_strategy(self, dist_strategy):
dist_strategy.a_sync_configs = {}
self.user_defined_strategy.a_sync_configs = {}
def _enable_strategy(self, dist_strategy):
dist_strategy.a_sync = True
dist_strategy.a_sync_configs = {}
......@@ -111,6 +111,10 @@ class PipelineOptimizer(MetaOptimizerBase):
dist_strategy.pipeline = False
dist_strategy.pipeline_configs = {}
def _enable_strategy(self, dist_strategy):
# we do not support enabling pipeline automatically right now
return
def minimize_impl(self,
loss,
startup_program=None,
......
......@@ -49,6 +49,10 @@ class RecomputeOptimizer(MetaOptimizerBase):
dist_strategy.recompute = False
dist_strategy.recompute_configs = {}
def _enable_strategy(self, dist_strategy):
# we currently do not support enabling recompute automatically
return
def backward(self,
loss,
startup_program=None,
......
......@@ -202,6 +202,9 @@ class ParameterServerRuntime(RuntimeBase):
if self.role_maker._get_heter_worker_device() == "GPU":
gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
executor = Executor(fluid.CUDAPlace(gpu_id))
elif self.role_maker._get_heter_worker_device() == "XPU":
xpu_id = int(os.getenv("FLAGS_selected_xpus", "0"))
executor = Executor(fluid.XPUPlace(xpu_id))
else:
raise ValueError("Not Support Device {}".format(
self.role_maker._get_heter_worker_device()))
......
......@@ -74,7 +74,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
continue
for in_var_name in op.input(in_name):
in_var = block.var(in_var_name)
if in_var.type not in valid_types:
if in_var.type not in valid_types or in_var.dtype == dest_dtype:
continue
if in_var.dtype == src_dtype:
cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
......@@ -84,7 +84,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
name=cast_name,
dtype=dest_dtype,
persistable=False,
stop_gradient=False)
stop_gradient=in_var.stop_gradient)
block._insert_op(
idx,
......@@ -100,7 +100,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
else:
if op.has_attr('in_dtype'):
op._set_attr('in_dtype', dest_dtype)
if src_dtype == core.VarDesc.VarType.FP32:
if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.FP16:
for out_name in op.output_names:
if op.type == 'batch_norm' and out_name != 'Y':
continue
......
......@@ -299,11 +299,14 @@ class Quant2Int8MkldnnPass(object):
# Convert int8 range weights to fp32 range weights
scales = self._weight_scales[output_var_name]
weight = self._load_param(self._scope, weight_var_name)
assert scales.size == 1 or scales.size == len(
weight
), "The size of weight scales vector ({}) does not match the number of output channels ({}) in the weights tensor {}.".format(
scales.size, len(weight), weight_var_name)
w_fp32 = np.divide(np.multiply(weight, self._s8_max).T, scales.T).T
if scales.size == 1 or scales.size == weight.shape[0]:
w_fp32 = np.divide(np.multiply(weight, self._s8_max).T, scales.T).T
elif len(weight.shape) > 1 and scales.size == weight.shape[1]:
w_fp32 = np.divide(np.multiply(weight, self._s8_max), scales)
else:
raise ValueError(
"The size of weight scales vector ({}) does not match the dimensions ({}) of the weights tensor {}."
.format(scales.size, weight.shape, weight_var_name))
w_fp32 = w_fp32.reshape(weight.shape).astype(np.float32)
self._restore_var(weight_var_name, w_fp32)
......
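The updated pass now accepts scale vectors matching either the first or the second dimension of the weight tensor. The numpy fragment below only illustrates which axis each branch broadcasts over, reusing the same expressions as the pass; the value of self._s8_max is assumed to be 127.
.. code-block:: python
import numpy as np
S8_MAX = 127.0  # assumed value of self._s8_max in the pass
weight = np.array([[64.0, -32.0],
                   [127.0,  0.0]])
scales = np.array([0.5, 2.0])
# scales.size == weight.shape[0]: one scale per row (output channel);
# transposing lets numpy broadcast the scales over rows.
w_rows = np.divide(np.multiply(weight, S8_MAX).T, scales.T).T
# scales.size == weight.shape[1]: one scale per column (the newly added branch);
# plain division broadcasts the scales over columns.
w_cols = np.divide(np.multiply(weight, S8_MAX), scales)
print(w_rows)  # row i divided by scales[i]
print(w_cols)  # column j divided by scales[j]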
......@@ -132,6 +132,28 @@ def check_dtype(input_dtype,
extra_message))
def check_shape(shape,
op_name,
expected_shape_type=(list, tuple, Variable),
expected_element_type=(int, Variable),
expected_tensor_dtype=('int32', 'int64')):
# See NOTE [ Why skip dynamic graph check ]
if in_dygraph_mode():
return
check_type(shape, 'shape', expected_shape_type, op_name)
if expected_element_type is not None and not isinstance(shape, Variable):
for item in shape:
check_type(item, 'element of shape', expected_element_type, op_name)
if expected_tensor_dtype is not None and isinstance(item, Variable):
check_dtype(
item.dtype, 'element of shape', expected_tensor_dtype,
op_name,
'If element of shape is Tensor, its data type should be {}'.
format(', '.join(expected_tensor_dtype)))
if expected_tensor_dtype is not None and isinstance(shape, Variable):
check_dtype(shape.dtype, 'shape', expected_tensor_dtype, op_name)
class DataToLoDTensorConverter(object):
def __init__(self, place, lod_level, shape, dtype):
self.place = place
......
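check_shape mirrors the existing check_dtype/check_type helpers: it validates the shape argument of an op before the static graph is built and is skipped in dygraph mode. A hypothetical wrapper below shows how an op might call it; the import path and the op name are assumptions made for illustration.
.. code-block:: python
# hypothetical example; the module path is assumed to be paddle.fluid.data_feeder
from paddle.fluid.data_feeder import check_shape
def my_fill_op(shape, dtype='float32'):
    # Accepts a list/tuple/Variable; elements must be python ints or
    # int32/int64 Tensors, otherwise a TypeError naming the op is raised.
    check_shape(shape, 'my_fill_op')
    # ... build the op with the validated shape ...
my_fill_op([2, 3])        # passes
# my_fill_op([2, 3.5])    # would raise TypeError: element of shape must be int/Variable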
......@@ -31,6 +31,7 @@ from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogger
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticLayer
from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators
from paddle.fluid.dygraph.layers import Layer
# TODO(liym27): A better way to do this.
......@@ -118,14 +119,9 @@ def convert_call(func):
func_self = None
converted_call = None
# Function in convert_call may be decorated by another `@declarative`,
# Function in convert_call may be decorated by another `@to_static`,
# in this case, unwraps it into a raw method or function.
if isinstance(func, StaticLayer):
instance = func._class_instance
if instance is not None:
func = func.dygraph_function.__get__(instance)
else:
func = func.dygraph_function
_, func = unwrap_decorators(func)
if is_builtin_len(func):
return convert_len
......@@ -155,7 +151,8 @@ def convert_call(func):
if inspect.isfunction(fn):
global_functions.add(fn)
elif isinstance(fn, StaticLayer):
global_functions.add(fn.dygraph_function)
_, fn = unwrap_decorators(fn)
global_functions.add(fn)
if func in global_functions:
converted_call = convert_to_static(func)
......@@ -189,7 +186,8 @@ def convert_call(func):
elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'):
if hasattr(func, 'forward') and isinstance(func, Layer):
try:
forward_func = convert_to_static(func.forward)
_, forward_func = unwrap_decorators(func.forward)
forward_func = convert_to_static(forward_func)
setattr(func, 'forward', forward_func)
func_self = func
except Exception:
......
......@@ -97,7 +97,7 @@ class LearningRateDecay(object):
"""
self.keys = ['step_num']
def set_dict(self, state_dict):
def set_state_dict(self, state_dict):
"""
Loads the schedulers state.
"""
......@@ -114,6 +114,9 @@ class LearningRateDecay(object):
"There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict"
)
# [aliases] Compatible with old method names
set_dict = set_state_dict
def step(self):
raise NotImplementedError()
......
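With this rename, set_state_dict becomes the canonical name and set_dict is kept as a backward-compatible alias. A minimal dygraph sketch follows, using ExponentialDecay purely as an illustrative LearningRateDecay subclass.
.. code-block:: python
import paddle.fluid as fluid
with fluid.dygraph.guard():
    scheduler = fluid.dygraph.ExponentialDecay(
        learning_rate=0.1, decay_steps=1000, decay_rate=0.9)
    state = scheduler.state_dict()      # contains the scheduler's 'step_num'
    scheduler.set_state_dict(state)     # new name
    scheduler.set_dict(state)           # old name, still works via the alias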