diff --git a/Dockerfile b/Dockerfile index 870304a6acc99e715dffbfabd8058be000b6872c..9ac58f37f2893613ca9f82be08136d9da674737e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,8 @@ # A image for building paddle binaries # Use cuda devel base image for both cpu and gpu environment + +# When you modify it, please be aware of cudnn-runtime version +# and libcudnn.so.x in paddle/scripts/docker/build.sh FROM nvidia/cuda:8.0-cudnn7-devel-ubuntu16.04 MAINTAINER PaddlePaddle Authors diff --git a/paddle/fluid/framework/block_desc.cc b/paddle/fluid/framework/block_desc.cc index b8847e4b909cbab67b2ddb6885b45b73d402de19..9f753478d8ecf12441d4b1745a9f6750a1038e31 100644 --- a/paddle/fluid/framework/block_desc.cc +++ b/paddle/fluid/framework/block_desc.cc @@ -146,6 +146,7 @@ void BlockDesc::RemoveOp(size_t s, size_t e) { if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) { return; } + need_update_ = true; ops_.erase(ops_.begin() + s, ops_.begin() + e); } diff --git a/paddle/fluid/framework/data_device_transform_test.cu b/paddle/fluid/framework/data_device_transform_test.cu index a66525303da58601f85c40c41854edaf22c3d4ea..df4caa45eba2470f7528d2fbd99cca39cae0b596 100644 --- a/paddle/fluid/framework/data_device_transform_test.cu +++ b/paddle/fluid/framework/data_device_transform_test.cu @@ -103,9 +103,7 @@ static void BuildVar(const std::string& param_name, } TEST(Operator, CPUtoGPU) { - using namespace paddle::framework; - using namespace paddle::platform; - InitDevices(true); + paddle::framework::InitDevices(true); paddle::framework::Scope scope; paddle::platform::CPUPlace cpu_place; @@ -118,8 +116,9 @@ TEST(Operator, CPUtoGPU) { auto cpu_op = paddle::framework::OpRegistry::CreateOp(cpu_op_desc); // prepare input - auto* in_t = scope.Var("IN1")->GetMutable(); - auto* src_ptr = in_t->mutable_data({2, 3}, CPUPlace()); + auto* in_t = scope.Var("IN1")->GetMutable(); + auto* src_ptr = + in_t->mutable_data({2, 3}, paddle::platform::CPUPlace()); for (int i = 0; i < 2 * 3; ++i) { src_ptr[i] = static_cast(i); } @@ -128,7 +127,7 @@ TEST(Operator, CPUtoGPU) { auto* output = scope.Var("OUT1"); cpu_op->Run(scope, cpu_place); - auto* output_ptr = output->Get().data(); + auto* output_ptr = output->Get().data(); for (int i = 0; i < 2 * 3; ++i) { ASSERT_EQ(output_ptr[i], static_cast(i) * 2); } @@ -153,12 +152,14 @@ TEST(Operator, CPUtoGPU) { VLOG(3) << "after gpu_op run"; // auto* output2_ptr = output2->Get().data(); - DeviceContextPool& pool = DeviceContextPool::Instance(); + paddle::platform::DeviceContextPool& pool = + paddle::platform::DeviceContextPool::Instance(); auto dev_ctx = pool.Get(cuda_place); paddle::framework::Tensor output_tensor; - TensorCopy(output2->Get(), paddle::platform::CPUPlace(), *dev_ctx, - &output_tensor); + paddle::framework::TensorCopy(output2->Get(), + paddle::platform::CPUPlace(), *dev_ctx, + &output_tensor); dev_ctx->Wait(); float* output2_ptr = output_tensor.data(); diff --git a/paddle/fluid/framework/data_layout_transform_test.cc b/paddle/fluid/framework/data_layout_transform_test.cc index dd17cac0e10db0d058d399cc725e18dcb14be507..a0d08826b854fea9256382f0e065fd59dda8c8b3 100644 --- a/paddle/fluid/framework/data_layout_transform_test.cc +++ b/paddle/fluid/framework/data_layout_transform_test.cc @@ -18,27 +18,28 @@ #include "paddle/fluid/platform/device_context.h" TEST(DataTransform, DataLayoutFunction) { - using namespace paddle::framework; - using namespace paddle::platform; - - auto place = CPUPlace(); - Tensor in = Tensor(); - Tensor out = Tensor(); - in.mutable_data(make_ddim({2, 3, 1, 2}), place); - in.set_layout(DataLayout::kNHWC); - - auto kernel_nhwc = OpKernelType(proto::VarType::FP32, place, - DataLayout::kNHWC, LibraryType::kPlain); - auto kernel_ncwh = OpKernelType(proto::VarType::FP32, place, - DataLayout::kNCHW, LibraryType::kPlain); - - TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out); - - EXPECT_TRUE(out.layout() == DataLayout::kNCHW); - EXPECT_TRUE(out.dims() == make_ddim({2, 2, 3, 1})); + auto place = paddle::platform::CPUPlace(); + paddle::framework::Tensor in = paddle::framework::Tensor(); + paddle::framework::Tensor out = paddle::framework::Tensor(); + in.mutable_data(paddle::framework::make_ddim({2, 3, 1, 2}), place); + in.set_layout(paddle::framework::DataLayout::kNHWC); + + auto kernel_nhwc = paddle::framework::OpKernelType( + paddle::framework::proto::VarType::FP32, place, + paddle::framework::DataLayout::kNHWC, + paddle::framework::LibraryType::kPlain); + auto kernel_ncwh = paddle::framework::OpKernelType( + paddle::framework::proto::VarType::FP32, place, + paddle::framework::DataLayout::kNCHW, + paddle::framework::LibraryType::kPlain); + + paddle::framework::TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out); + + EXPECT_TRUE(out.layout() == paddle::framework::DataLayout::kNCHW); + EXPECT_TRUE(out.dims() == paddle::framework::make_ddim({2, 2, 3, 1})); TransDataLayout(kernel_ncwh, kernel_nhwc, in, &out); - EXPECT_TRUE(in.layout() == DataLayout::kNHWC); - EXPECT_TRUE(in.dims() == make_ddim({2, 3, 1, 2})); + EXPECT_TRUE(in.layout() == paddle::framework::DataLayout::kNHWC); + EXPECT_TRUE(in.dims() == paddle::framework::make_ddim({2, 3, 1, 2})); } diff --git a/paddle/fluid/framework/data_type_transform_test.cc b/paddle/fluid/framework/data_type_transform_test.cc index 6b9a8f5e28b372c45abfaa2c20575a55d9a9dd03..bbebea9f13fd37469a0e9b7be9719aca128f5687 100644 --- a/paddle/fluid/framework/data_type_transform_test.cc +++ b/paddle/fluid/framework/data_type_transform_test.cc @@ -17,43 +17,58 @@ limitations under the License. */ #include "gtest/gtest.h" TEST(DataTypeTransform, CPUTransform) { - using namespace paddle::framework; - using namespace paddle::platform; - - auto place = CPUPlace(); - - auto kernel_fp16 = OpKernelType(proto::VarType::FP16, place, - DataLayout::kAnyLayout, LibraryType::kPlain); - auto kernel_fp32 = OpKernelType(proto::VarType::FP32, place, - DataLayout::kAnyLayout, LibraryType::kPlain); - auto kernel_fp64 = OpKernelType(proto::VarType::FP64, place, - DataLayout::kAnyLayout, LibraryType::kPlain); - auto kernel_int32 = OpKernelType(proto::VarType::INT32, place, - DataLayout::kAnyLayout, LibraryType::kPlain); - auto kernel_int64 = OpKernelType(proto::VarType::INT64, place, - DataLayout::kAnyLayout, LibraryType::kPlain); - auto kernel_bool = OpKernelType(proto::VarType::BOOL, place, - DataLayout::kAnyLayout, LibraryType::kPlain); + auto place = paddle::platform::CPUPlace(); + + auto kernel_fp16 = paddle::framework::OpKernelType( + paddle::framework::proto::VarType::FP16, place, + paddle::framework::DataLayout::kAnyLayout, + paddle::framework::LibraryType::kPlain); + + auto kernel_fp32 = paddle::framework::OpKernelType( + paddle::framework::proto::VarType::FP32, place, + paddle::framework::DataLayout::kAnyLayout, + paddle::framework::LibraryType::kPlain); + + auto kernel_fp64 = paddle::framework::OpKernelType( + paddle::framework::proto::VarType::FP64, place, + paddle::framework::DataLayout::kAnyLayout, + paddle::framework::LibraryType::kPlain); + + auto kernel_int32 = paddle::framework::OpKernelType( + paddle::framework::proto::VarType::INT32, place, + paddle::framework::DataLayout::kAnyLayout, + paddle::framework::LibraryType::kPlain); + + auto kernel_int64 = paddle::framework::OpKernelType( + paddle::framework::proto::VarType::INT64, place, + paddle::framework::DataLayout::kAnyLayout, + paddle::framework::LibraryType::kPlain); + + auto kernel_bool = paddle::framework::OpKernelType( + paddle::framework::proto::VarType::BOOL, place, + paddle::framework::DataLayout::kAnyLayout, + paddle::framework::LibraryType::kPlain); // data type transform from float32 { - Tensor in; - Tensor out; + paddle::framework::Tensor in; + paddle::framework::Tensor out; - float* ptr = in.mutable_data(make_ddim({2, 3}), place); + float* ptr = + in.mutable_data(paddle::framework::make_ddim({2, 3}), place); int data_number = 2 * 3; for (int i = 0; i < data_number; ++i) { ptr[i] = i / 3; } - TransDataType(kernel_fp32, kernel_fp64, in, &out); + paddle::framework::TransDataType(kernel_fp32, kernel_fp64, in, &out); double* out_data_double = out.data(); for (int i = 0; i < data_number; ++i) { EXPECT_EQ(out_data_double[i], static_cast(i / 3)); } - TransDataType(kernel_fp32, kernel_int32, in, &out); + paddle::framework::TransDataType(kernel_fp32, kernel_int32, in, &out); int* out_data_int = out.data(); for (int i = 0; i < data_number; ++i) { EXPECT_EQ(out_data_int[i], static_cast(i / 3)); @@ -62,10 +77,11 @@ TEST(DataTypeTransform, CPUTransform) { // data type transform from/to float16 { - Tensor in; - Tensor out; + paddle::framework::Tensor in; + paddle::framework::Tensor out; - float16* ptr = in.mutable_data(make_ddim({2, 3}), place); + paddle::platform::float16* ptr = in.mutable_data( + paddle::framework::make_ddim({2, 3}), place); int data_number = 2 * 3; for (int i = 0; i < data_number; ++i) { @@ -73,94 +89,104 @@ TEST(DataTypeTransform, CPUTransform) { } // transform from float16 to other data types - TransDataType(kernel_fp16, kernel_fp32, in, &out); + paddle::framework::TransDataType(kernel_fp16, kernel_fp32, in, &out); float* out_data_float = out.data(); for (int i = 0; i < data_number; ++i) { EXPECT_EQ(out_data_float[i], static_cast(ptr[i])); } - TransDataType(kernel_fp16, kernel_fp64, in, &out); + paddle::framework::TransDataType(kernel_fp16, kernel_fp64, in, &out); double* out_data_double = out.data(); for (int i = 0; i < data_number; ++i) { EXPECT_EQ(out_data_double[i], static_cast(ptr[i])); } - TransDataType(kernel_fp16, kernel_int32, in, &out); + paddle::framework::TransDataType(kernel_fp16, kernel_int32, in, &out); int* out_data_int = out.data(); for (int i = 0; i < data_number; ++i) { EXPECT_EQ(out_data_int[i], static_cast(ptr[i])); } - TransDataType(kernel_fp16, kernel_int64, in, &out); + paddle::framework::TransDataType(kernel_fp16, kernel_int64, in, &out); int64_t* out_data_int64 = out.data(); for (int i = 0; i < data_number; ++i) { EXPECT_EQ(out_data_int64[i], static_cast(ptr[i])); } - TransDataType(kernel_fp16, kernel_bool, in, &out); + paddle::framework::TransDataType(kernel_fp16, kernel_bool, in, &out); bool* out_data_bool = out.data(); for (int i = 0; i < data_number; ++i) { EXPECT_EQ(out_data_bool[i], static_cast(ptr[i])); } // transform float to float16 - float* in_data_float = in.mutable_data(make_ddim({2, 3}), place); + float* in_data_float = + in.mutable_data(paddle::framework::make_ddim({2, 3}), place); for (int i = 0; i < data_number; ++i) { in_data_float[i] = i; } - TransDataType(kernel_fp32, kernel_fp16, in, &out); - ptr = out.data(); + paddle::framework::TransDataType(kernel_fp32, kernel_fp16, in, &out); + ptr = out.data(); for (int i = 0; i < data_number; ++i) { - EXPECT_EQ(ptr[i].x, static_cast(in_data_float[i]).x); + EXPECT_EQ(ptr[i].x, + static_cast(in_data_float[i]).x); } // transform double to float16 - double* in_data_double = in.mutable_data(make_ddim({2, 3}), place); + double* in_data_double = + in.mutable_data(paddle::framework::make_ddim({2, 3}), place); for (int i = 0; i < data_number; ++i) { in_data_double[i] = i; } - TransDataType(kernel_fp64, kernel_fp16, in, &out); - ptr = out.data(); + paddle::framework::TransDataType(kernel_fp64, kernel_fp16, in, &out); + ptr = out.data(); for (int i = 0; i < data_number; ++i) { - EXPECT_EQ(ptr[i].x, static_cast(in_data_double[i]).x); + EXPECT_EQ(ptr[i].x, + static_cast(in_data_double[i]).x); } // transform int to float16 - int* in_data_int = in.mutable_data(make_ddim({2, 3}), place); + int* in_data_int = + in.mutable_data(paddle::framework::make_ddim({2, 3}), place); for (int i = 0; i < data_number; ++i) { in_data_int[i] = i; } - TransDataType(kernel_int32, kernel_fp16, in, &out); - ptr = out.data(); + paddle::framework::TransDataType(kernel_int32, kernel_fp16, in, &out); + ptr = out.data(); for (int i = 0; i < data_number; ++i) { - EXPECT_EQ(ptr[i].x, static_cast(in_data_int[i]).x); + EXPECT_EQ(ptr[i].x, + static_cast(in_data_int[i]).x); } // transform int64 to float16 - int64_t* in_data_int64 = in.mutable_data(make_ddim({2, 3}), place); + int64_t* in_data_int64 = + in.mutable_data(paddle::framework::make_ddim({2, 3}), place); for (int i = 0; i < data_number; ++i) { in_data_int64[i] = i; } - TransDataType(kernel_int64, kernel_fp16, in, &out); - ptr = out.data(); + paddle::framework::TransDataType(kernel_int64, kernel_fp16, in, &out); + ptr = out.data(); for (int i = 0; i < data_number; ++i) { - EXPECT_EQ(ptr[i].x, static_cast(in_data_int64[i]).x); + EXPECT_EQ(ptr[i].x, + static_cast(in_data_int64[i]).x); } // transform bool to float16 - bool* in_data_bool = in.mutable_data(make_ddim({2, 3}), place); + bool* in_data_bool = + in.mutable_data(paddle::framework::make_ddim({2, 3}), place); for (int i = 0; i < data_number; ++i) { in_data_bool[i] = i; } - TransDataType(kernel_bool, kernel_fp16, in, &out); - ptr = out.data(); + paddle::framework::TransDataType(kernel_bool, kernel_fp16, in, &out); + ptr = out.data(); for (int i = 0; i < data_number; ++i) { - EXPECT_EQ(ptr[i].x, static_cast(in_data_bool[i]).x); + EXPECT_EQ(ptr[i].x, + static_cast(in_data_bool[i]).x); } } } diff --git a/paddle/fluid/framework/data_type_transform_test.cu b/paddle/fluid/framework/data_type_transform_test.cu index de389ddabcb86de0155757406a406e44086c5474..0874509a8797cd2ff1b1fcb347b4ef3b74a39047 100644 --- a/paddle/fluid/framework/data_type_transform_test.cu +++ b/paddle/fluid/framework/data_type_transform_test.cu @@ -18,42 +18,58 @@ limitations under the License. */ #include "gtest/gtest.h" TEST(DataTypeTransform, GPUTransform) { - using namespace paddle::framework; - using namespace paddle::platform; - - auto cpu_place = CPUPlace(); - auto gpu_place = CUDAPlace(0); - CUDADeviceContext context(gpu_place); - - auto kernel_fp16 = OpKernelType(proto::VarType::FP16, gpu_place, - DataLayout::kAnyLayout, LibraryType::kPlain); - auto kernel_fp32 = OpKernelType(proto::VarType::FP32, gpu_place, - DataLayout::kAnyLayout, LibraryType::kPlain); - auto kernel_fp64 = OpKernelType(proto::VarType::FP64, gpu_place, - DataLayout::kAnyLayout, LibraryType::kPlain); - auto kernel_int32 = OpKernelType(proto::VarType::INT32, gpu_place, - DataLayout::kAnyLayout, LibraryType::kPlain); - auto kernel_int64 = OpKernelType(proto::VarType::INT64, gpu_place, - DataLayout::kAnyLayout, LibraryType::kPlain); - auto kernel_bool = OpKernelType(proto::VarType::BOOL, gpu_place, - DataLayout::kAnyLayout, LibraryType::kPlain); + auto cpu_place = paddle::platform::CPUPlace(); + auto gpu_place = paddle::platform::CUDAPlace(0); + paddle::platform::CUDADeviceContext context(gpu_place); + + auto kernel_fp16 = paddle::framework::OpKernelType( + paddle::framework::proto::VarType::FP16, gpu_place, + paddle::framework::DataLayout::kAnyLayout, + paddle::framework::LibraryType::kPlain); + + auto kernel_fp32 = paddle::framework::OpKernelType( + paddle::framework::proto::VarType::FP32, gpu_place, + paddle::framework::DataLayout::kAnyLayout, + paddle::framework::LibraryType::kPlain); + + auto kernel_fp64 = paddle::framework::OpKernelType( + paddle::framework::proto::VarType::FP64, gpu_place, + paddle::framework::DataLayout::kAnyLayout, + paddle::framework::LibraryType::kPlain); + + auto kernel_int32 = paddle::framework::OpKernelType( + paddle::framework::proto::VarType::INT32, gpu_place, + paddle::framework::DataLayout::kAnyLayout, + paddle::framework::LibraryType::kPlain); + + auto kernel_int64 = paddle::framework::OpKernelType( + paddle::framework::proto::VarType::INT64, gpu_place, + paddle::framework::DataLayout::kAnyLayout, + paddle::framework::LibraryType::kPlain); + + auto kernel_bool = paddle::framework::OpKernelType( + paddle::framework::proto::VarType::BOOL, gpu_place, + paddle::framework::DataLayout::kAnyLayout, + paddle::framework::LibraryType::kPlain); // data type transform from float32 { - Tensor in; - Tensor in_gpu; - Tensor out_gpu; - Tensor out; + paddle::framework::Tensor in; + paddle::framework::Tensor in_gpu; + paddle::framework::Tensor out_gpu; + paddle::framework::Tensor out; - float* in_ptr = in.mutable_data(make_ddim({2, 3}), cpu_place); + float* in_ptr = + in.mutable_data(paddle::framework::make_ddim({2, 3}), cpu_place); float arr[6] = {0, 1, 2, 3, 4, 5}; int data_number = sizeof(arr) / sizeof(arr[0]); memcpy(in_ptr, arr, sizeof(arr)); - TensorCopy(in, gpu_place, context, &in_gpu); + paddle::framework::TensorCopy(in, gpu_place, context, &in_gpu); context.Wait(); - TransDataType(kernel_fp32, kernel_fp64, in_gpu, &out_gpu); - TensorCopy(out_gpu, cpu_place, context, &out); + paddle::framework::TransDataType(kernel_fp32, kernel_fp64, in_gpu, + &out_gpu); + paddle::framework::TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); double* out_data_double = out.data(); @@ -61,8 +77,9 @@ TEST(DataTypeTransform, GPUTransform) { EXPECT_EQ(out_data_double[i], static_cast(arr[i])); } - TransDataType(kernel_fp32, kernel_int32, in_gpu, &out_gpu); - TensorCopy(out_gpu, cpu_place, context, &out); + paddle::framework::TransDataType(kernel_fp32, kernel_int32, in_gpu, + &out_gpu); + paddle::framework::TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); int* out_data_int = out.data(); @@ -73,22 +90,27 @@ TEST(DataTypeTransform, GPUTransform) { // data type transform from/to float16 { - Tensor in; - Tensor in_gpu; - Tensor out_gpu; - Tensor out; - - float16* ptr = in.mutable_data(make_ddim({2, 3}), cpu_place); - float16 arr[6] = {float16(0), float16(1), float16(2), - float16(3), float16(4), float16(5)}; + paddle::framework::Tensor in; + paddle::framework::Tensor in_gpu; + paddle::framework::Tensor out_gpu; + paddle::framework::Tensor out; + + paddle::platform::float16* ptr = in.mutable_data( + paddle::framework::make_ddim({2, 3}), cpu_place); + paddle::platform::float16 arr[6] = { + paddle::platform::float16(0), paddle::platform::float16(1), + paddle::platform::float16(2), paddle::platform::float16(3), + paddle::platform::float16(4), paddle::platform::float16(5)}; + int data_number = sizeof(arr) / sizeof(arr[0]); memcpy(ptr, arr, sizeof(arr)); - TensorCopy(in, gpu_place, context, &in_gpu); + paddle::framework::TensorCopy(in, gpu_place, context, &in_gpu); context.Wait(); // transform from float16 to other data types - TransDataType(kernel_fp16, kernel_fp32, in_gpu, &out_gpu); - TensorCopy(out_gpu, cpu_place, context, &out); + paddle::framework::TransDataType(kernel_fp16, kernel_fp32, in_gpu, + &out_gpu); + paddle::framework::TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); float* out_data_float = out.data(); @@ -96,8 +118,9 @@ TEST(DataTypeTransform, GPUTransform) { EXPECT_EQ(out_data_float[i], static_cast(ptr[i])); } - TransDataType(kernel_fp16, kernel_fp64, in_gpu, &out_gpu); - TensorCopy(out_gpu, cpu_place, context, &out); + paddle::framework::TransDataType(kernel_fp16, kernel_fp64, in_gpu, + &out_gpu); + paddle::framework::TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); double* out_data_double = out.data(); @@ -105,8 +128,9 @@ TEST(DataTypeTransform, GPUTransform) { EXPECT_EQ(out_data_double[i], static_cast(ptr[i])); } - TransDataType(kernel_fp16, kernel_int32, in_gpu, &out_gpu); - TensorCopy(out_gpu, cpu_place, context, &out); + paddle::framework::TransDataType(kernel_fp16, kernel_int32, in_gpu, + &out_gpu); + paddle::framework::TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); int* out_data_int = out.data(); @@ -114,8 +138,9 @@ TEST(DataTypeTransform, GPUTransform) { EXPECT_EQ(out_data_int[i], static_cast(ptr[i])); } - TransDataType(kernel_fp16, kernel_int64, in_gpu, &out_gpu); - TensorCopy(out_gpu, cpu_place, context, &out); + paddle::framework::TransDataType(kernel_fp16, kernel_int64, in_gpu, + &out_gpu); + paddle::framework::TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); int64_t* out_data_int64 = out.data(); @@ -123,8 +148,9 @@ TEST(DataTypeTransform, GPUTransform) { EXPECT_EQ(out_data_int64[i], static_cast(ptr[i])); } - TransDataType(kernel_fp16, kernel_bool, in_gpu, &out_gpu); - TensorCopy(out_gpu, cpu_place, context, &out); + paddle::framework::TransDataType(kernel_fp16, kernel_bool, in_gpu, + &out_gpu); + paddle::framework::TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); bool* out_data_bool = out.data(); @@ -133,90 +159,103 @@ TEST(DataTypeTransform, GPUTransform) { } // transform float to float16 - float* in_data_float = in.mutable_data(make_ddim({2, 3}), cpu_place); + float* in_data_float = + in.mutable_data(paddle::framework::make_ddim({2, 3}), cpu_place); for (int i = 0; i < data_number; ++i) { in_data_float[i] = i; } - TensorCopy(in, gpu_place, context, &in_gpu); + paddle::framework::TensorCopy(in, gpu_place, context, &in_gpu); context.Wait(); - TransDataType(kernel_fp32, kernel_fp16, in_gpu, &out_gpu); - TensorCopy(out_gpu, cpu_place, context, &out); + paddle::framework::TransDataType(kernel_fp32, kernel_fp16, in_gpu, + &out_gpu); + paddle::framework::TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); - ptr = out.data(); + ptr = out.data(); for (int i = 0; i < data_number; ++i) { - EXPECT_EQ(ptr[i].x, static_cast(in_data_float[i]).x); + EXPECT_EQ(ptr[i].x, + static_cast(in_data_float[i]).x); } // transform double to float16 - double* in_data_double = - in.mutable_data(make_ddim({2, 3}), cpu_place); + double* in_data_double = in.mutable_data( + paddle::framework::make_ddim({2, 3}), cpu_place); for (int i = 0; i < data_number; ++i) { in_data_double[i] = i; } - TensorCopy(in, gpu_place, context, &in_gpu); + paddle::framework::TensorCopy(in, gpu_place, context, &in_gpu); context.Wait(); - TransDataType(kernel_fp64, kernel_fp16, in_gpu, &out_gpu); - TensorCopy(out_gpu, cpu_place, context, &out); + paddle::framework::TransDataType(kernel_fp64, kernel_fp16, in_gpu, + &out_gpu); + paddle::framework::TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); - ptr = out.data(); + ptr = out.data(); for (int i = 0; i < data_number; ++i) { - EXPECT_EQ(ptr[i].x, static_cast(in_data_double[i]).x); + EXPECT_EQ(ptr[i].x, + static_cast(in_data_double[i]).x); } // transform int to float16 - int* in_data_int = in.mutable_data(make_ddim({2, 3}), cpu_place); + int* in_data_int = + in.mutable_data(paddle::framework::make_ddim({2, 3}), cpu_place); for (int i = 0; i < data_number; ++i) { in_data_int[i] = i; } - TensorCopy(in, gpu_place, context, &in_gpu); + paddle::framework::TensorCopy(in, gpu_place, context, &in_gpu); context.Wait(); - TransDataType(kernel_int32, kernel_fp16, in_gpu, &out_gpu); - TensorCopy(out_gpu, cpu_place, context, &out); + paddle::framework::TransDataType(kernel_int32, kernel_fp16, in_gpu, + &out_gpu); + paddle::framework::TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); - ptr = out.data(); + ptr = out.data(); for (int i = 0; i < data_number; ++i) { - EXPECT_EQ(ptr[i].x, static_cast(in_data_int[i]).x); + EXPECT_EQ(ptr[i].x, + static_cast(in_data_int[i]).x); } // transform int64 to float16 - int64_t* in_data_int64 = - in.mutable_data(make_ddim({2, 3}), cpu_place); + int64_t* in_data_int64 = in.mutable_data( + paddle::framework::make_ddim({2, 3}), cpu_place); for (int i = 0; i < data_number; ++i) { in_data_int64[i] = i; } - TensorCopy(in, gpu_place, context, &in_gpu); + paddle::framework::TensorCopy(in, gpu_place, context, &in_gpu); context.Wait(); - TransDataType(kernel_int64, kernel_fp16, in_gpu, &out_gpu); - TensorCopy(out_gpu, cpu_place, context, &out); + paddle::framework::TransDataType(kernel_int64, kernel_fp16, in_gpu, + &out_gpu); + paddle::framework::TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); - ptr = out.data(); + ptr = out.data(); for (int i = 0; i < data_number; ++i) { - EXPECT_EQ(ptr[i].x, static_cast(in_data_int64[i]).x); + EXPECT_EQ(ptr[i].x, + static_cast(in_data_int64[i]).x); } // transform bool to float16 - bool* in_data_bool = in.mutable_data(make_ddim({2, 3}), cpu_place); + bool* in_data_bool = + in.mutable_data(paddle::framework::make_ddim({2, 3}), cpu_place); for (int i = 0; i < data_number; ++i) { in_data_bool[i] = i; } - TensorCopy(in, gpu_place, context, &in_gpu); + paddle::framework::TensorCopy(in, gpu_place, context, &in_gpu); context.Wait(); - TransDataType(kernel_bool, kernel_fp16, in_gpu, &out_gpu); - TensorCopy(out_gpu, cpu_place, context, &out); + paddle::framework::TransDataType(kernel_bool, kernel_fp16, in_gpu, + &out_gpu); + paddle::framework::TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); - ptr = out.data(); + ptr = out.data(); for (int i = 0; i < data_number; ++i) { - EXPECT_EQ(ptr[i].x, static_cast(in_data_bool[i]).x); + EXPECT_EQ(ptr[i].x, + static_cast(in_data_bool[i]).x); } } } diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 6f990e28666829dd2f2fe6f07362188a77ae6468..96c181f983a33961e3d5fb8745740f2fdbb210de 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -8,27 +8,28 @@ cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope plac cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph) +cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) + if(WITH_GPU) nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory dynload_cuda) set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle) - nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base scope ddim dynload_cuda) + nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda) else() set(multi_devices_graph_builder_deps) - cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base scope ddim) + cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim) endif() + +cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) +cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) + cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle - scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps}) + scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) -cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) - -cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base variable_visitor scope ddim memory) -cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope variable_visitor ddim memory) - cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory device_context broadcast_op_handle) cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc index 0bc3ee78d67e8548f093ff7086cf06a1ffb1c58b..33e02ab65a251a338225ee621ff14acbb0631992 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle.cc @@ -44,9 +44,15 @@ void BroadcastOpHandle::RunImpl() { // &in_place; WaitInputVarGenerated(*in_var_handle); - auto *in_var = local_scopes_.at(in_var_handle->scope_idx_) - ->FindVar(in_var_handle->name_); + std::vector var_scopes; + for (auto *s : local_scopes_) { + var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get()); + } + + auto *in_var = + var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_); PADDLE_ENFORCE_NOT_NULL(in_var); + Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); for (auto *out : out_var_handles) { @@ -55,17 +61,16 @@ void BroadcastOpHandle::RunImpl() { } auto &out_p = out->place_; - auto *out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_); - + auto *out_var = var_scopes.at(out->scope_idx_)->FindVar(out->name_); + PADDLE_ENFORCE_NOT_NULL(out_var); PADDLE_ENFORCE_EQ(out_p.which(), in_var_handle->place_.which(), "Places must be all on CPU or all on CUDA."); VariableVisitor::ShareDimsAndLoD(*in_var, out_var); - VariableVisitor::GetMutableTensor(out_var) - .Resize(in_tensor.dims()) - .mutable_data(out_p, in_tensor.type()); + VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p, + in_tensor.type()); - auto dev_ctx = dev_ctxes_[out_p]; + auto dev_ctx = dev_ctxes_.at(out_p); RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] { paddle::framework::TensorCopy( in_tensor, out_p, *(dev_ctx), diff --git a/paddle/fluid/framework/details/broadcast_op_handle_test.cc b/paddle/fluid/framework/details/broadcast_op_handle_test.cc index efc70515820d18fe61696fd697b0af0a0fef3834..3f2dcde3e9597287d72046dd4f8b07faab1ede25 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle_test.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle_test.cc @@ -30,6 +30,7 @@ const f::DDim kDims = {20, 20}; struct TestBroadcastOpHandle { std::vector> ctxs_; std::vector local_scopes_; + std::vector param_scopes_; Scope g_scope_; std::unique_ptr op_handle_; std::vector> vars_; @@ -72,11 +73,17 @@ struct TestBroadcastOpHandle { void InitBroadcastOp(size_t input_scope_idx) { for (size_t j = 0; j < gpu_list_.size(); ++j) { local_scopes_.push_back(&(g_scope_.NewScope())); - local_scopes_[j]->Var("out"); + Scope& local_scope = local_scopes_.back()->NewScope(); + *local_scopes_.back() + ->Var(details::kLocalExecScopeName) + ->GetMutable() = &local_scope; + local_scope.Var("out"); + param_scopes_.emplace_back(&local_scope); } - local_scopes_[input_scope_idx]->Var("input"); + param_scopes_[input_scope_idx]->Var("input"); op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_)); + auto* in_var_handle = new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]); vars_.emplace_back(in_var_handle); @@ -105,7 +112,8 @@ struct TestBroadcastOpHandle { } void TestBroadcastLodTensor(size_t input_scope_idx) { - auto in_var = local_scopes_[input_scope_idx]->Var("input"); + auto in_var = param_scopes_[input_scope_idx]->FindVar("input"); + PADDLE_ENFORCE_NOT_NULL(in_var); auto in_lod_tensor = in_var->GetMutable(); in_lod_tensor->mutable_data(kDims, gpu_list_[input_scope_idx]); @@ -117,6 +125,7 @@ struct TestBroadcastOpHandle { paddle::framework::TensorFromVector( send_vector, *(ctxs_[input_scope_idx]), in_lod_tensor); in_lod_tensor->set_lod(lod); + in_lod_tensor->Resize(kDims); op_handle_->Run(false); @@ -124,7 +133,8 @@ struct TestBroadcastOpHandle { p::CPUPlace cpu_place; for (size_t j = 0; j < gpu_list_.size(); ++j) { - auto out_var = local_scopes_[j]->Var("out"); + auto out_var = param_scopes_[j]->FindVar("out"); + PADDLE_ENFORCE_NOT_NULL(out_var); auto out_tensor = out_var->Get(); PADDLE_ENFORCE_EQ(out_tensor.lod(), lod, "lod is not equal."); @@ -139,7 +149,8 @@ struct TestBroadcastOpHandle { } void TestBroadcastSelectedRows(size_t input_scope_idx) { - auto in_var = local_scopes_[input_scope_idx]->Var("input"); + auto in_var = param_scopes_[input_scope_idx]->FindVar("input"); + PADDLE_ENFORCE_NOT_NULL(in_var); auto in_selected_rows = in_var->GetMutable(); auto value = in_selected_rows->mutable_value(); value->mutable_data(kDims, gpu_list_[input_scope_idx]); @@ -162,7 +173,8 @@ struct TestBroadcastOpHandle { p::CPUPlace cpu_place; for (size_t j = 0; j < gpu_list_.size(); ++j) { - auto out_var = local_scopes_[j]->Var("out"); + auto out_var = param_scopes_[j]->FindVar("out"); + PADDLE_ENFORCE_NOT_NULL(out_var); auto& out_select_rows = out_var->Get(); auto rt = out_select_rows.value(); diff --git a/paddle/fluid/framework/details/cow_ptr.h b/paddle/fluid/framework/details/cow_ptr.h index 69bcea625288eba897e761a1d634f19c41dc0f79..21f75957be5f33f3dfc09c41fa9a1e1ca590f99e 100644 --- a/paddle/fluid/framework/details/cow_ptr.h +++ b/paddle/fluid/framework/details/cow_ptr.h @@ -14,7 +14,7 @@ #pragma once #include -#include +#include // NOLINT namespace paddle { namespace framework { @@ -23,7 +23,7 @@ namespace details { // Change it to thread safe flags if needed. class ThreadUnsafeOwnershipFlags { public: - ThreadUnsafeOwnershipFlags(bool flag) : flag_(flag) {} + explicit ThreadUnsafeOwnershipFlags(bool flag) : flag_(flag) {} ThreadUnsafeOwnershipFlags(const ThreadUnsafeOwnershipFlags& other) = delete; ThreadUnsafeOwnershipFlags& operator=( diff --git a/paddle/fluid/framework/details/gather_op_handle.cc b/paddle/fluid/framework/details/gather_op_handle.cc index 511fd941dc7270d79f9a565f03d233b6fdf41d37..3ed7723919fc3a547b15c28b846de758a8155e66 100644 --- a/paddle/fluid/framework/details/gather_op_handle.cc +++ b/paddle/fluid/framework/details/gather_op_handle.cc @@ -41,14 +41,19 @@ void GatherOpHandle::RunImpl() { out_var_handle = out_var_handles.front(); } + std::vector var_scopes; + for (auto *s : local_scopes_) { + var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get()); + } + auto in_0_handle = in_var_handles[0]; auto pre_in_var = - local_scopes_[in_0_handle->scope_idx_]->FindVar(in_0_handle->name_); - auto pre_place = in_0_handle->place_; - + var_scopes.at(in_0_handle->scope_idx_)->FindVar(in_0_handle->name_); + PADDLE_ENFORCE_NOT_NULL(pre_in_var); PADDLE_ENFORCE(pre_in_var->IsType(), "Currently, gather_op only can gather SelectedRows."); + auto pre_place = in_0_handle->place_; PADDLE_ENFORCE_EQ(out_var_handle->place_.which(), pre_place.which(), "The place of input and output should be the same."); @@ -67,7 +72,7 @@ void GatherOpHandle::RunImpl() { PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(), "Places must be all on CPU or all on CUDA."); auto *in_var = - local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_); + var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_); auto &in_sr = in_var->Get(); PADDLE_ENFORCE_EQ(in_sr.value().type(), pre_in.value().type(), @@ -86,7 +91,7 @@ void GatherOpHandle::RunImpl() { // write the output auto &out_place = out_var_handle->place_; auto out_scope_idx = out_var_handle->scope_idx_; - auto out_var = local_scopes_[out_scope_idx]->FindVar(out_var_handle->name_); + auto out_var = var_scopes.at(out_scope_idx)->FindVar(out_var_handle->name_); auto out = out_var->GetMutable(); out->set_height(pre_in.height()); diff --git a/paddle/fluid/framework/details/gather_op_handle_test.cc b/paddle/fluid/framework/details/gather_op_handle_test.cc index 9481579f6c6f8272ab7b78a15d57c09a4d3245a4..3cce2cc1640b3866130126424ff8fef18b8befc6 100644 --- a/paddle/fluid/framework/details/gather_op_handle_test.cc +++ b/paddle/fluid/framework/details/gather_op_handle_test.cc @@ -29,6 +29,7 @@ const f::DDim kDims = {20, 20}; struct TestGatherOpHandle { std::vector> ctxs_; std::vector local_scopes_; + std::vector param_scopes_; Scope g_scope_; std::unique_ptr op_handle_; std::vector> vars_; @@ -71,9 +72,14 @@ struct TestGatherOpHandle { void InitGatherOp(size_t input_scope_idx) { for (size_t j = 0; j < gpu_list_.size(); ++j) { local_scopes_.push_back(&(g_scope_.NewScope())); - local_scopes_[j]->Var("out"); + Scope& local_scope = local_scopes_.back()->NewScope(); + *local_scopes_.back() + ->Var(details::kLocalExecScopeName) + ->GetMutable() = &local_scope; + local_scope.Var("input"); + param_scopes_.emplace_back(&local_scope); } - local_scopes_[input_scope_idx]->Var("input"); + param_scopes_[input_scope_idx]->Var("out"); op_handle_.reset(new GatherOpHandle(local_scopes_, gpu_list_)); // add input @@ -115,7 +121,8 @@ struct TestGatherOpHandle { for (size_t input_scope_idx = 0; input_scope_idx < gpu_list_.size(); ++input_scope_idx) { - auto in_var = local_scopes_[input_scope_idx]->Var("input"); + auto in_var = param_scopes_.at(input_scope_idx)->FindVar("input"); + PADDLE_ENFORCE_NOT_NULL(in_var); auto in_selected_rows = in_var->GetMutable(); auto value = in_selected_rows->mutable_value(); value->mutable_data(kDims, gpu_list_[input_scope_idx]); @@ -128,10 +135,11 @@ struct TestGatherOpHandle { value->Resize(kDims); } - auto out_var = local_scopes_[output_scope_idx]->Var("out"); + auto out_var = param_scopes_.at(output_scope_idx)->FindVar("out"); + PADDLE_ENFORCE_NOT_NULL(out_var); auto out_selected_rows = out_var->GetMutable(); - auto in_var = local_scopes_[output_scope_idx]->Var("input"); + auto in_var = param_scopes_.at(output_scope_idx)->FindVar("input"); auto in_selected_rows = in_var->GetMutable(); out_selected_rows->mutable_value()->ShareDataWith( @@ -155,7 +163,8 @@ struct TestGatherOpHandle { f::TensorCopy(rt, cpu_place, *(ctxs_[output_scope_idx]), &result_tensor); float* ct = result_tensor.data(); - for (int64_t j = 0; j < f::product(kDims); ++j) { + for (int64_t j = 0; + j < f::product(kDims) * static_cast(gpu_list_.size()); ++j) { ASSERT_NEAR(ct[j], send_vector[j % send_vector.size()], 1e-5); } } diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc index 28f9139987faa3dfee1e7733fb846a4d4efadc7b..b055bb48f608c9fd9cc671d175cb463d25dc489b 100644 --- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc @@ -43,21 +43,21 @@ void NCCLAllReduceOpHandle::RunImpl() { int dtype = -1; size_t numel = 0; - std::vector lod_tensors; + std::vector lod_tensors; for (size_t i = 0; i < local_scopes_.size(); ++i) { auto *s = local_scopes_[i]; auto &local_scope = *s->FindVar(kLocalExecScopeName)->Get(); auto &lod_tensor = local_scope.FindVar(var_name)->Get(); - lod_tensors.emplace_back(lod_tensor); + lod_tensors.emplace_back(&lod_tensor); } - if (platform::is_gpu_place(lod_tensors[0].place())) { + if (platform::is_gpu_place(lod_tensors[0]->place())) { std::vector> all_reduce_calls; for (size_t i = 0; i < local_scopes_.size(); ++i) { auto &p = places_[i]; - auto &lod_tensor = lod_tensors[i]; + auto &lod_tensor = *lod_tensors[i]; void *buffer = const_cast(lod_tensor.data()); if (dtype == -1) { @@ -93,7 +93,7 @@ void NCCLAllReduceOpHandle::RunImpl() { // Reduce All Tensor to trg in CPU ReduceLoDTensor func(lod_tensors, &trg); - VisitDataType(ToDataType(lod_tensors[0].type()), func); + VisitDataType(ToDataType(lod_tensors[0]->type()), func); for (size_t i = 0; i < local_scopes_.size(); ++i) { auto &scope = diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h index d73604ad185a66ade0168f585d1951d0d7d4a5f9..06603db31e0092382c0cc05482a038473d647ef1 100644 --- a/paddle/fluid/framework/details/op_registry.h +++ b/paddle/fluid/framework/details/op_registry.h @@ -14,6 +14,9 @@ limitations under the License. */ #pragma once +#include +#include +#include #include "paddle/fluid/framework/grad_op_desc_maker.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_proto_maker.h" diff --git a/paddle/fluid/framework/details/reduce_and_gather.h b/paddle/fluid/framework/details/reduce_and_gather.h index 7957fba8a449f7dc05588fad335df0b45a34b575..2b95a284990da8f9b7c16d6e4221eb1ed061f74b 100644 --- a/paddle/fluid/framework/details/reduce_and_gather.h +++ b/paddle/fluid/framework/details/reduce_and_gather.h @@ -24,23 +24,23 @@ namespace framework { namespace details { struct ReduceLoDTensor { - const std::vector &src_tensors_; + const std::vector &src_tensors_; LoDTensor &dst_tensor_; - ReduceLoDTensor(const std::vector &src, LoDTensor *dst) + ReduceLoDTensor(const std::vector &src, LoDTensor *dst) : src_tensors_(src), dst_tensor_(*dst) {} template void operator()() const { PADDLE_ENFORCE(!src_tensors_.empty()); - auto &t0 = src_tensors_[0]; + auto &t0 = *src_tensors_[0]; PADDLE_ENFORCE_NE(t0.numel(), 0); dst_tensor_.Resize(t0.dims()); T *dst = dst_tensor_.mutable_data(platform::CPUPlace()); std::copy(t0.data(), t0.data() + t0.numel(), dst); for (size_t i = 1; i < src_tensors_.size(); ++i) { - auto &t = src_tensors_[i]; + auto &t = *src_tensors_[i]; PADDLE_ENFORCE_EQ(t.dims(), t0.dims()); PADDLE_ENFORCE_EQ(t.type(), t0.type()); std::transform(t.data(), t.data() + t.numel(), dst, dst, diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc index c951de5dd5a7b66ce03c705e9bdcbe3f5c3e565d..409e8f72b841de03dcb50e62de447ae9895df2c0 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.cc +++ b/paddle/fluid/framework/details/reduce_op_handle.cc @@ -13,7 +13,9 @@ // limitations under the License. #include "paddle/fluid/framework/details/reduce_op_handle.h" +#include "paddle/fluid/framework/details/container_cast.h" #include "paddle/fluid/framework/details/reduce_and_gather.h" +#include "paddle/fluid/framework/details/variable_visitor.h" namespace paddle { namespace framework { @@ -21,85 +23,84 @@ namespace details { void ReduceOpHandle::RunImpl() { // the input and output may have dummy var. - std::vector in_var_handles = GetValidVarHandles(inputs_); - std::vector out_var_handles = GetValidVarHandles(outputs_); + auto in_var_handles = DynamicCast(inputs_); PADDLE_ENFORCE_EQ( in_var_handles.size(), places_.size(), "The number of output should equal to the number of places."); - PADDLE_ENFORCE_EQ(out_var_handles.size(), 1, - "The number of output should be one."); - // Wait input done, this Wait is asynchronous operation - WaitEvents(in_var_handles); + VarHandle *out_var_handle; + { + auto out_var_handles = DynamicCast(outputs_); + + PADDLE_ENFORCE_EQ(out_var_handles.size(), 1, + "The number of output should be one."); + out_var_handle = out_var_handles.front(); + } - // check in the same place auto in_0_handle = in_var_handles[0]; - auto pre_place = in_0_handle->place_; + std::vector var_scopes; + for (auto *s : local_scopes_) { + var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get()); + } + + auto pre_in_var = + var_scopes.at(in_0_handle->scope_idx_)->FindVar(in_0_handle->name_); + PADDLE_ENFORCE_NOT_NULL(pre_in_var); + + // Wait input done, this Wait is asynchronous operation + WaitInputVarGenerated(in_var_handles); + auto pre_place = in_0_handle->place_; std::vector in_places; + auto pre_in_tensor = VariableVisitor::GetMutableTensor(pre_in_var); for (auto *in_handle : in_var_handles) { auto in_p = in_handle->place_; PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(), "Places must be all on CPU or all on CUDA."); in_places.emplace_back(in_p); - } - auto out_var = local_scopes_[out_var_handles[0]->scope_idx_]->FindVar( - out_var_handles[0]->name_); + auto in_var = + var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_); + PADDLE_ENFORCE_NOT_NULL(in_var); - auto pre_in_var = - local_scopes_[in_0_handle->scope_idx_]->FindVar(in_0_handle->name_); - - if (pre_in_var->IsType()) { - auto &pre_in = pre_in_var->Get(); - std::vector in_selected_rows; + auto in_tensor = VariableVisitor::GetMutableTensor(in_var); + PADDLE_ENFORCE_EQ(in_tensor.type(), pre_in_tensor.type(), + "The type of input is not consistent."); + } - for (auto *in_handle : in_var_handles) { - auto in_var = - local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_); - auto &in_sr = in_var->Get(); + auto out_var = + var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_); + PADDLE_ENFORCE_NOT_NULL(out_var); - PADDLE_ENFORCE_EQ(in_sr.value().type(), pre_in.value().type(), - "The type of input is not consistent."); + if (pre_in_var->IsType()) { + std::vector in_selected_rows = + GetInputValues(in_var_handles, var_scopes); - in_selected_rows.emplace_back(&in_sr); - } - auto trg = out_var->GetMutable(); GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, - out_var_handles[0]->place_, trg); + out_var_handle->place_, + out_var->GetMutable()); } else { - auto pre_in = pre_in_var->Get(); - std::vector lod_tensors; - - // can be refined - for (auto *in_handle : in_var_handles) { - auto in_var = - local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_); - auto &in_sr = in_var->Get(); - - PADDLE_ENFORCE_EQ(in_sr.type(), pre_in.type(), - "The type of input is not consistent."); - - lod_tensors.emplace_back(in_sr); - } - - auto trg = out_var->GetMutable(); - trg->Resize(pre_in.dims()); - trg->mutable_data(out_var_handles[0]->place_, pre_in.type()); + std::vector lod_tensors = + GetInputValues(in_var_handles, var_scopes); if (paddle::platform::is_cpu_place(pre_place)) { - ReduceLoDTensor func(lod_tensors, trg); - VisitDataType(ToDataType(lod_tensors[0].type()), func); + ReduceLoDTensor func(lod_tensors, + out_var->GetMutable()); + VisitDataType(ToDataType(lod_tensors[0]->type()), func); } else if (paddle::platform::is_gpu_place(pre_place)) { #ifdef PADDLE_WITH_CUDA - auto out_p = out_var_handles[0]->place_; - int root = boost::get(out_p).device; + auto pre_in = pre_in_var->Get(); + VariableVisitor::ShareDimsAndLoD(*pre_in_var, out_var); + VariableVisitor::GetMutableTensor(out_var).mutable_data( + out_var_handle->place_, pre_in.type()); + auto out_p = out_var_handle->place_; + int root = boost::get(out_p).device; std::vector> all_reduce_calls; - for (size_t i = 0; i < local_scopes_.size(); ++i) { + for (size_t i = 0; i < var_scopes.size(); ++i) { auto &p = in_places[i]; - auto &lod_tensor = lod_tensors[i]; + auto &lod_tensor = *lod_tensors[i]; int dev_id = boost::get(p).device; auto &nccl_ctx = nccl_ctxs_->at(dev_id); @@ -109,14 +110,16 @@ void ReduceOpHandle::RunImpl() { void *buffer = const_cast(lod_tensor.data()); void *recvbuffer = nullptr; if (root == dev_id) { - recvbuffer = trg->mutable_data(out_var_handles[0]->place_); + recvbuffer = + out_var->GetMutable()->mutable_data( + out_var_handle->place_); } + int type = platform::ToNCCLDataType(lod_tensor.type()); all_reduce_calls.emplace_back([=] { PADDLE_ENFORCE(platform::dynload::ncclReduce( buffer, recvbuffer, static_cast(lod_tensor.numel()), - platform::ToNCCLDataType(lod_tensor.type()), ncclSum, root, comm, - stream)); + static_cast(type), ncclSum, root, comm, stream)); }); } @@ -135,26 +138,31 @@ void ReduceOpHandle::RunImpl() { } } -void ReduceOpHandle::WaitEvents( - const std::vector &in_var_handles) { - if (in_var_handles[0]->generated_op_) { - for (auto *in : in_var_handles) { - in_var_handles[0]->generated_op_->Wait(dev_ctxes_[in->place_]); - } +template +std::vector ReduceOpHandle::GetInputValues( + const std::vector &in_var_handles, + const std::vector &var_scopes) const { + std::vector in_selected_rows; + for (auto *in_handle : in_var_handles) { + auto &in_sr = var_scopes.at(in_handle->scope_idx_) + ->FindVar(in_handle->name_) + ->Get(); + in_selected_rows.emplace_back(&in_sr); } + return in_selected_rows; } -std::vector ReduceOpHandle::GetValidVarHandles( - const std::vector &inputs) { - std::vector in_var_handles; - for (auto *in : inputs) { - auto *in_handle = dynamic_cast(in); - if (in_handle) { - in_var_handles.push_back(in_handle); +void ReduceOpHandle::WaitInputVarGenerated( + const std::vector &in_var_handles) { + for (auto *in : in_var_handles) { + if (in->generated_op_) { + for (auto pair : dev_ctxes_) { + in->generated_op_->Wait(pair.second); + } } } - return in_var_handles; } + std::string ReduceOpHandle::Name() const { return "reduce"; } } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/reduce_op_handle.h b/paddle/fluid/framework/details/reduce_op_handle.h index 7b36ce4a7bceaeb93ceb03730b2d54d0f36fed3d..9746b3bdbde14d24a83a27a593c5f1ebfec201ff 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.h +++ b/paddle/fluid/framework/details/reduce_op_handle.h @@ -59,10 +59,13 @@ struct ReduceOpHandle : public OpHandleBase { protected: void RunImpl() override; - std::vector GetValidVarHandles( - const std::vector &inputs); - void WaitEvents(const std::vector &in_var_handles); + void WaitInputVarGenerated(const std::vector &in_var_handles); + + template + std::vector GetInputValues( + const std::vector &in_var_handles, + const std::vector &var_scopes) const; }; } // namespace details diff --git a/paddle/fluid/framework/details/reduce_op_handle_test.cc b/paddle/fluid/framework/details/reduce_op_handle_test.cc index a57c7882ed77ae8cfbf7e284058d94935975828b..c17aabee53680fba10eac289cf8f8bd5f7d419e8 100644 --- a/paddle/fluid/framework/details/reduce_op_handle_test.cc +++ b/paddle/fluid/framework/details/reduce_op_handle_test.cc @@ -14,7 +14,6 @@ #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "gtest/gtest.h" - #include "paddle/fluid/platform/device_context.h" namespace paddle { @@ -30,6 +29,7 @@ struct TestReduceOpHandle { bool use_gpu_; Scope g_scope_; std::vector local_scopes_; + std::vector param_scopes_; std::unique_ptr op_handle_; std::vector> vars_; std::vector gpu_list_; @@ -83,12 +83,18 @@ struct TestReduceOpHandle { } } - void InitReduceOp(size_t input_scope_idx) { + void InitReduceOp(size_t out_scope_idx) { + // init scope for (size_t j = 0; j < gpu_list_.size(); ++j) { local_scopes_.push_back(&(g_scope_.NewScope())); - local_scopes_[j]->Var("out"); + Scope &local_scope = local_scopes_.back()->NewScope(); + *local_scopes_.back() + ->Var(details::kLocalExecScopeName) + ->GetMutable() = &local_scope; + local_scope.Var("input"); + param_scopes_.emplace_back(&local_scope); } - local_scopes_[input_scope_idx]->Var("input"); + param_scopes_[out_scope_idx]->Var("out"); if (use_gpu_) { #ifdef PADDLE_WITH_CUDA @@ -106,6 +112,7 @@ struct TestReduceOpHandle { #endif } + // init op handle // add input for (size_t j = 0; j < gpu_list_.size(); ++j) { if (!use_gpu_) { @@ -126,7 +133,7 @@ struct TestReduceOpHandle { // add output auto *out_var_handle = - new VarHandle(2, input_scope_idx, "out", gpu_list_[input_scope_idx]); + new VarHandle(2, out_scope_idx, "out", gpu_list_[out_scope_idx]); vars_.emplace_back(out_var_handle); op_handle_->AddOutput(out_var_handle); @@ -148,7 +155,8 @@ struct TestReduceOpHandle { for (size_t input_scope_idx = 0; input_scope_idx < gpu_list_.size(); ++input_scope_idx) { - auto in_var = local_scopes_[input_scope_idx]->Var("input"); + auto in_var = param_scopes_[input_scope_idx]->FindVar("input"); + PADDLE_ENFORCE_NOT_NULL(in_var); auto in_selected_rows = in_var->GetMutable(); auto value = in_selected_rows->mutable_value(); value->mutable_data(kDims, gpu_list_[input_scope_idx]); @@ -161,10 +169,11 @@ struct TestReduceOpHandle { value->Resize(kDims); } - auto out_var = local_scopes_[output_scope_idx]->Var("out"); + auto out_var = param_scopes_[output_scope_idx]->FindVar("out"); + PADDLE_ENFORCE_NOT_NULL(out_var); auto out_selected_rows = out_var->GetMutable(); - auto in_var = local_scopes_[output_scope_idx]->Var("input"); + auto in_var = param_scopes_[output_scope_idx]->FindVar("input"); auto in_selected_rows = in_var->GetMutable(); out_selected_rows->mutable_value()->ShareDataWith( @@ -202,7 +211,8 @@ struct TestReduceOpHandle { for (size_t input_scope_idx = 0; input_scope_idx < gpu_list_.size(); ++input_scope_idx) { - auto in_var = local_scopes_[input_scope_idx]->Var("input"); + auto in_var = param_scopes_[input_scope_idx]->FindVar("input"); + PADDLE_ENFORCE_NOT_NULL(in_var); auto in_lod_tensor = in_var->GetMutable(); in_lod_tensor->mutable_data(kDims, gpu_list_[input_scope_idx]); in_lod_tensor->set_lod(lod); @@ -211,10 +221,11 @@ struct TestReduceOpHandle { send_vector, *(ctxs_[input_scope_idx]), in_lod_tensor); } - auto out_var = local_scopes_[output_scope_idx]->Var("out"); + auto out_var = param_scopes_[output_scope_idx]->FindVar("out"); + PADDLE_ENFORCE_NOT_NULL(out_var); auto out_lodtensor = out_var->GetMutable(); - auto in_var = local_scopes_[output_scope_idx]->Var("input"); + auto in_var = param_scopes_[output_scope_idx]->FindVar("input"); auto in_lodtensor = in_var->Get(); out_lodtensor->ShareDataWith(in_lodtensor); @@ -239,34 +250,34 @@ struct TestReduceOpHandle { TEST(ReduceTester, TestCPUReduceTestSelectedRows) { TestReduceOpHandle test_op; - size_t input_scope_idx = 0; + size_t out_scope_idx = 0; test_op.InitCtxOnGpu(false); - test_op.InitReduceOp(input_scope_idx); - test_op.TestReduceSelectedRows(input_scope_idx); + test_op.InitReduceOp(out_scope_idx); + test_op.TestReduceSelectedRows(out_scope_idx); } TEST(ReduceTester, TestCPUReduceTestLodTensor) { TestReduceOpHandle test_op; - size_t input_scope_idx = 0; + size_t out_scope_idx = 0; test_op.InitCtxOnGpu(false); - test_op.InitReduceOp(input_scope_idx); - test_op.TestReduceLodTensors(input_scope_idx); + test_op.InitReduceOp(out_scope_idx); + test_op.TestReduceLodTensors(out_scope_idx); } #ifdef PADDLE_WITH_CUDA TEST(ReduceTester, TestGPUReduceTestSelectedRows) { TestReduceOpHandle test_op; - size_t input_scope_idx = 0; + size_t out_scope_idx = 0; test_op.InitCtxOnGpu(true); - test_op.InitReduceOp(input_scope_idx); - test_op.TestReduceSelectedRows(input_scope_idx); + test_op.InitReduceOp(out_scope_idx); + test_op.TestReduceSelectedRows(out_scope_idx); } TEST(ReduceTester, TestGPUReduceTestLodTensor) { TestReduceOpHandle test_op; - size_t input_scope_idx = 0; + size_t out_scope_idx = 0; test_op.InitCtxOnGpu(true); - test_op.InitReduceOp(input_scope_idx); - test_op.TestReduceLodTensors(input_scope_idx); + test_op.InitReduceOp(out_scope_idx); + test_op.TestReduceLodTensors(out_scope_idx); } #endif diff --git a/paddle/fluid/framework/op_registry_test.cc b/paddle/fluid/framework/op_registry_test.cc index 0d791c8583537d410b838c1662755938353052a9..6dc4cf261bad3c004aa53fba5502fe166e3a47f7 100644 --- a/paddle/fluid/framework/op_registry_test.cc +++ b/paddle/fluid/framework/op_registry_test.cc @@ -202,8 +202,9 @@ class CosineOpComplete : public paddle::framework::CosineOp { }; TEST(OperatorRegistrar, Test) { - using namespace paddle::framework; - OperatorRegistrar reg("cos"); + paddle::framework::OperatorRegistrar< + CosineOpComplete, paddle::framework::CosineOpProtoAndCheckerMaker> + reg("cos"); } namespace paddle { diff --git a/paddle/fluid/framework/operator_test.cc b/paddle/fluid/framework/operator_test.cc index 25f622b725277ac9bcca4622902162f3edf147e8..1bf8c81469bb4afdd00921cfa0acf6089dedbbaa 100644 --- a/paddle/fluid/framework/operator_test.cc +++ b/paddle/fluid/framework/operator_test.cc @@ -226,10 +226,8 @@ REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel, // test with multi inputs TEST(OpKernel, multi_inputs) { - using namespace paddle::framework; - paddle::framework::InitDevices(true); - proto::OpDesc op_desc; + paddle::framework::proto::OpDesc op_desc; op_desc.set_type("op_multi_inputs_with_kernel"); BuildVar("xs", {"x0", "x1", "x2"}, op_desc.add_inputs()); @@ -243,12 +241,12 @@ TEST(OpKernel, multi_inputs) { paddle::platform::CPUPlace cpu_place; paddle::framework::Scope scope; - scope.Var("x0")->GetMutable(); - scope.Var("x1")->GetMutable(); - scope.Var("x2")->GetMutable(); - scope.Var("k0")->GetMutable(); - scope.Var("y0")->GetMutable(); - scope.Var("y1")->GetMutable(); + scope.Var("x0")->GetMutable(); + scope.Var("x1")->GetMutable(); + scope.Var("x2")->GetMutable(); + scope.Var("k0")->GetMutable(); + scope.Var("y0")->GetMutable(); + scope.Var("y1")->GetMutable(); auto op = paddle::framework::OpRegistry::CreateOp(op_desc); op->Run(scope, cpu_place); diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc index 77d17fbbccca0292e21acd5e8fa90448527b95c0..16694bcf76486a9603c41dc19a58dd0a7cb2b719 100644 --- a/paddle/fluid/framework/program_desc.cc +++ b/paddle/fluid/framework/program_desc.cc @@ -27,10 +27,14 @@ BlockDesc *ProgramDesc::AppendBlock(const BlockDesc &parent) { return blocks_.back().get(); } -proto::ProgramDesc *ProgramDesc::Proto() { +void ProgramDesc::Flush() { for (auto &block : blocks_) { block->Flush(); } +} + +proto::ProgramDesc *ProgramDesc::Proto() { + Flush(); return &desc_; } diff --git a/paddle/fluid/framework/program_desc.h b/paddle/fluid/framework/program_desc.h index 4288081be72c44c0fc3584b50c41a270eac9e204..65fa0a0cfd5ba6d9b8765cee1309e118cb74348a 100644 --- a/paddle/fluid/framework/program_desc.h +++ b/paddle/fluid/framework/program_desc.h @@ -51,6 +51,8 @@ class ProgramDesc { size_t Size() const { return blocks_.size(); } + void Flush(); + proto::ProgramDesc *Proto(); // The output variable of feed_op is referenced as feed_target. diff --git a/paddle/fluid/framework/threadpool_test.cc b/paddle/fluid/framework/threadpool_test.cc index 4da83d630a5632233ddff6f08174dcabc1c696f8..27a4ffd4fcbf293a3dea1744b29384d0bee0c137 100644 --- a/paddle/fluid/framework/threadpool_test.cc +++ b/paddle/fluid/framework/threadpool_test.cc @@ -15,14 +15,14 @@ limitations under the License. */ #include #include -#include "threadpool.h" +#include "paddle/fluid/framework/threadpool.h" namespace framework = paddle::framework; -void do_sum(framework::ThreadPool* pool, std::atomic& sum, int cnt) { +void do_sum(framework::ThreadPool* pool, std::atomic* sum, int cnt) { std::vector> fs; for (int i = 0; i < cnt; ++i) { - fs.push_back(framework::Async([&sum]() { sum.fetch_add(1); })); + fs.push_back(framework::Async([sum]() { sum->fetch_add(1); })); } } @@ -46,7 +46,7 @@ TEST(ThreadPool, ConcurrentRun) { int n = 50; // sum = (n * (n + 1)) / 2 for (int i = 1; i <= n; ++i) { - std::thread t(do_sum, pool, std::ref(sum), i); + std::thread t(do_sum, pool, &sum, i); threads.push_back(std::move(t)); } for (auto& t : threads) { diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index 3b58019db6e55fa8198d2f77731095c6cf356266..78d2f16746cf478c4424df929bd1f62b08f8a67c 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/inference/io.h" +#include #include #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/feed_fetch_type.h" @@ -27,14 +28,14 @@ namespace inference { // linking the inference shared library. void Init(bool init_p2p) { framework::InitDevices(init_p2p); } -void ReadBinaryFile(const std::string& filename, std::string& contents) { +void ReadBinaryFile(const std::string& filename, std::string* contents) { std::ifstream fin(filename, std::ios::in | std::ios::binary); PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s", filename); fin.seekg(0, std::ios::end); - contents.clear(); - contents.resize(fin.tellg()); + contents->clear(); + contents->resize(fin.tellg()); fin.seekg(0, std::ios::beg); - fin.read(&contents[0], contents.size()); + fin.read(&(contents->at(0)), contents->size()); fin.close(); } @@ -47,7 +48,7 @@ bool IsPersistable(const framework::VarDesc* var) { return false; } -void LoadPersistables(framework::Executor& executor, framework::Scope& scope, +void LoadPersistables(framework::Executor* executor, framework::Scope* scope, const framework::ProgramDesc& main_program, const std::string& dirname, const std::string& param_filename) { @@ -92,18 +93,18 @@ void LoadPersistables(framework::Executor& executor, framework::Scope& scope, op->CheckAttrs(); } - executor.Run(*load_program, &scope, 0, true, true); + executor->Run(*load_program, scope, 0, true, true); delete load_program; } -std::unique_ptr Load(framework::Executor& executor, - framework::Scope& scope, +std::unique_ptr Load(framework::Executor* executor, + framework::Scope* scope, const std::string& dirname) { std::string model_filename = dirname + "/__model__"; std::string program_desc_str; VLOG(3) << "loading model from " << model_filename; - ReadBinaryFile(model_filename, program_desc_str); + ReadBinaryFile(model_filename, &program_desc_str); std::unique_ptr main_program( new framework::ProgramDesc(program_desc_str)); @@ -113,11 +114,11 @@ std::unique_ptr Load(framework::Executor& executor, } std::unique_ptr Load( - framework::Executor& executor, framework::Scope& scope, + framework::Executor* executor, framework::Scope* scope, const std::string& prog_filename, const std::string& param_filename) { std::string model_filename = prog_filename; std::string program_desc_str; - ReadBinaryFile(model_filename, program_desc_str); + ReadBinaryFile(model_filename, &program_desc_str); std::unique_ptr main_program( new framework::ProgramDesc(program_desc_str)); diff --git a/paddle/fluid/inference/io.h b/paddle/fluid/inference/io.h index 756c936b33ad55e2994542b171b945e248ba2e21..ba3e45099ae7c1626bf11d9527d4fa4c7f772fec 100644 --- a/paddle/fluid/inference/io.h +++ b/paddle/fluid/inference/io.h @@ -27,17 +27,17 @@ namespace inference { void Init(bool init_p2p); -void LoadPersistables(framework::Executor& executor, framework::Scope& scope, +void LoadPersistables(framework::Executor* executor, framework::Scope* scope, const framework::ProgramDesc& main_program, const std::string& dirname, const std::string& param_filename); -std::unique_ptr Load(framework::Executor& executor, - framework::Scope& scope, +std::unique_ptr Load(framework::Executor* executor, + framework::Scope* scope, const std::string& dirname); -std::unique_ptr Load(framework::Executor& executor, - framework::Scope& scope, +std::unique_ptr Load(framework::Executor* executor, + framework::Scope* scope, const std::string& prog_filename, const std::string& param_filename); diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index c3a8d0889c6a6dd9591837ccc523da56f8d13661..117472599f7c4874ab05e29c6ecb46fd61d0db9c 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -133,12 +133,12 @@ void TestInference(const std::string& dirname, std::string prog_filename = "__model_combined__"; std::string param_filename = "__params_combined__"; inference_program = paddle::inference::Load( - executor, *scope, dirname + "/" + prog_filename, + &executor, scope, dirname + "/" + prog_filename, dirname + "/" + param_filename); } else { // Parameters are saved in separate files sited in the specified // `dirname`. - inference_program = paddle::inference::Load(executor, *scope, dirname); + inference_program = paddle::inference::Load(&executor, scope, dirname); } } // Disable the profiler and print the timing information diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index d3bdd67ba8bd9ff699b686fd9e51fd1c15fa3100..2d38663df9b06910145f461a8e5a9ad2cc7d1a84 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -163,7 +163,12 @@ function(op_library TARGET) # pybind USE_OP if (${pybind_flag} EQUAL 0) + # NOTE(*): activation use macro to regist the kernels, set use_op manually. + if(${TARGET} STREQUAL "activation") + file(APPEND ${pybind_file} "USE_OP(relu);\n") + else() file(APPEND ${pybind_file} "USE_OP(${TARGET});\n") + endif() endif() endfunction() diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 56451f8f147adb65cf64e4a5948eb626b87749b7..549629ffd664825c40b8cd89811f31c6ab390fd3 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -13,11 +13,48 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/activation_op.h" +#include #include "paddle/fluid/operators/mkldnn_activation_op.h" namespace paddle { namespace operators { +#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \ + class OP_NAME##OpMaker \ + : public ::paddle::framework::OpProtoAndCheckerMaker { \ + public: \ + OP_NAME##OpMaker(OpProto *proto, OpAttrChecker *op_checker) \ + : ::paddle::framework::OpProtoAndCheckerMaker(proto, op_checker) { \ + AddInput("X", "Input of " #OP_NAME "operator"); \ + AddOutput("Out", "Output of" #OP_NAME "operator"); \ + AddAttr("use_mkldnn", \ + "(bool, default false) Only used in mkldnn kernel") \ + .SetDefault(false); \ + AddComment(#OP_COMMENT); \ + } \ + } + +#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) \ + class OP_NAME##GradMaker \ + : public ::paddle::framework::SingleGradOpDescMaker { \ + public: \ + using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker; \ + \ + protected: \ + std::unique_ptr<::paddle::framework::OpDesc> Apply() const override { \ + auto *op = new ::paddle::framework::OpDesc(); \ + op->SetType(#KERNEL_TYPE "_grad"); \ + op->SetInput("Out", Output("Out")); \ + op->SetInput(::paddle::framework::GradVarName("Out"), \ + OutputGrad("Out")); \ + \ + op->SetAttrMap(Attrs()); \ + \ + op->SetOutput(::paddle::framework::GradVarName("X"), InputGrad("X")); \ + return std::unique_ptr<::paddle::framework::OpDesc>(op); \ + } \ + } + class ActivationOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -37,346 +74,190 @@ class ActivationOpGrad : public framework::OperatorWithKernel { } }; -class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { - public: - SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of Sigmoid operator"); - AddOutput("Out", "Output of Sigmoid operator"); - AddComment(R"DOC( +constexpr char SigmoidDoc[] = R"DOC( Sigmoid Activation Operator $$out = \frac{1}{1 + e^{-x}}$$ -)DOC"); - } -}; +)DOC"; -class LogSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { - public: - LogSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of LogSigmoid operator"); - AddOutput("Out", "Output of LogSigmoid operator"); - AddComment(R"DOC( +constexpr char LogSigmoidDoc[] = R"DOC( Logsigmoid Activation Operator $$out = \log \frac{1}{1 + e^{-x}}$$ -)DOC"); - } -}; +)DOC"; -class ExpOpMaker : public framework::OpProtoAndCheckerMaker { - public: - ExpOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of Exp operator"); - AddOutput("Out", "Output of Exp operator"); - AddComment(R"DOC( +constexpr char ExpDoc[] = R"DOC( Exp Activation Operator. $out = e^x$ -)DOC"); - } -}; +)DOC"; -class ReluOpMaker : public framework::OpProtoAndCheckerMaker { - public: - ReluOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of Relu operator"); - AddOutput("Out", "Output of Relu operator"); - AddAttr("use_mkldnn", - "(bool, default false) Only used in mkldnn kernel") - .SetDefault(false); - AddComment(R"DOC( +constexpr char ReluDoc[] = R"DOC( Relu Activation Operator. $out = \max(x, 0)$ -)DOC"); - } -}; - -class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker { - public: - LeakyReluOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of LeakyRelu operator"); - AddOutput("Out", "Output of LeakyRelu operator"); - AddAttr("alpha", "The small negative slope").SetDefault(0.02f); - AddComment(R"DOC( -LeakyRelu Activation Operator. - -$out = \max(x, \alpha * x)$ - -)DOC"); - } -}; - -class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker { - public: - SoftShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of Softshrink operator"); - AddOutput("Out", "Output of Softshrink operator"); - AddAttr("lambda", "non-negative offset").SetDefault(0.5f); - AddComment(R"DOC( -Softshrink Activation Operator. - -$$ -out = \begin{cases} - x - \lambda, \text{if } x > \lambda \\ - x + \lambda, \text{if } x < -\lambda \\ - 0, \text{otherwise} - \end{cases} -$$ +)DOC"; -)DOC"); - } -}; - -class TanhOpMaker : public framework::OpProtoAndCheckerMaker { - public: - TanhOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of Tanh operator"); - AddOutput("Out", "Output of Tanh operator"); - AddAttr("use_mkldnn", - "(bool, default false) Only used in mkldnn kernel") - .SetDefault(false); - AddComment(R"DOC( +constexpr char TanhDoc[] = R"DOC( Tanh Activation Operator. $$out = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$ -)DOC"); - } -}; +)DOC"; -class TanhShrinkOpMaker : public framework::OpProtoAndCheckerMaker { - public: - TanhShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of TanhShrink operator"); - AddOutput("Out", "Output of TanhShrink operator"); - AddComment(R"DOC( +constexpr char TanhShrinkDoc[] = R"DOC( TanhShrink Activation Operator. $$out = x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$ -)DOC"); - } -}; - -class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker { - public: - HardShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of HardShrink operator"); - AddOutput("Out", "Output of HardShrink operator"); - AddAttr("threshold", "The value of threshold for HardShrink") - .SetDefault(0.5f); - AddComment(R"DOC( -HardShrink Activation Operator. +)DOC"; -$$ -out = \begin{cases} - x, \text{if } x > \lambda \\ - x, \text{if } x < -\lambda \\ - 0, \text{otherwise} - \end{cases} -$$ - -)DOC"); - } -}; - -class SqrtOpMaker : public framework::OpProtoAndCheckerMaker { - public: - SqrtOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of Sqrt operator"); - AddOutput("Out", "Output of Sqrt operator"); - AddAttr("use_mkldnn", - "(bool, default false) Only used in mkldnn kernel") - .SetDefault(false); - AddComment(R"DOC( +constexpr char SqrtDoc[] = R"DOC( Sqrt Activation Operator. $out = \sqrt{x}$ -)DOC"); - } -}; +)DOC"; -class AbsOpMaker : public framework::OpProtoAndCheckerMaker { - public: - AbsOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of Abs operator"); - AddOutput("Out", "Output of Abs operator"); - AddAttr("use_mkldnn", - "(bool, default false) Only used in mkldnn kernel") - .SetDefault(false); - AddComment(R"DOC( +constexpr char AbsDoc[] = R"DOC( Abs Activation Operator. $out = |x|$ -)DOC"); - } -}; +)DOC"; -class CeilOpMaker : public framework::OpProtoAndCheckerMaker { - public: - CeilOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of Ceil operator"); - AddOutput("Out", "Output of Ceil operator"); - AddComment(R"DOC( +constexpr char CeilDoc[] = R"DOC( Ceil Activation Operator. $out = ceil(x)$ -)DOC"); - } -}; +)DOC"; -class FloorOpMaker : public framework::OpProtoAndCheckerMaker { - public: - FloorOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of Floor operator"); - AddOutput("Out", "Output of Floor operator"); - AddComment(R"DOC( +constexpr char FloorDoc[] = R"DOC( Floor Activation Operator. $out = floor(x)$ -)DOC"); - } -}; +)DOC"; -class CosOpMaker : public framework::OpProtoAndCheckerMaker { - public: - CosOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of Cosine operator"); - AddOutput("Out", "Output of Cosine operator"); - AddComment(R"DOC( +constexpr char CosDoc[] = R"DOC( Cosine Activation Operator. $out = cos(x)$ -)DOC"); - } -}; +)DOC"; -class SinOpMaker : public framework::OpProtoAndCheckerMaker { - public: - SinOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of Sine operator"); - AddOutput("Out", "Output of Sine operator"); - AddComment(R"DOC( +constexpr char SinDoc[] = R"DOC( Sine Activation Operator. $out = sin(x)$ -)DOC"); - } -}; +)DOC"; -class RoundOpMaker : public framework::OpProtoAndCheckerMaker { - public: - RoundOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of Round operator"); - AddOutput("Out", "Output of Round operator"); - AddComment(R"DOC( +constexpr char RoundDoc[] = R"DOC( Round Activation Operator. $out = [x]$ -)DOC"); - } -}; +)DOC"; -class ReciprocalOpMaker : public framework::OpProtoAndCheckerMaker { - public: - ReciprocalOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of Reciprocal operator"); - AddOutput("Out", "Output of Reciprocal operator"); - AddComment(R"DOC( +constexpr char ReciprocalDoc[] = R"DOC( Reciprocal Activation Operator. $$out = \frac{1}{x}$$ -)DOC"); - } -}; +)DOC"; -class LogOpMaker : public framework::OpProtoAndCheckerMaker { - public: - LogOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of Log operator"); - AddOutput("Out", "Output of Log operator"); - AddComment(R"DOC( +constexpr char LogDoc[] = R"DOC( Log Activation Operator. $out = \ln(x)$ Natural logarithm of x. -)DOC"); - } -}; +)DOC"; + +constexpr char SquareDoc[] = R"DOC( +Square Activation Operator. + +$out = x^2$ -class SquareOpMaker : public framework::OpProtoAndCheckerMaker { +)DOC"; + +constexpr char SoftplusDoc[] = R"DOC( +Softplus Activation Operator. + +$out = \ln(1 + e^{x})$ + +)DOC"; + +constexpr char SoftsignDoc[] = R"DOC( +Softsign Activation Operator. + +$$out = \frac{x}{1 + |x|}$$ + +)DOC"; + +class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker { public: - SquareOpMaker(OpProto *proto, OpAttrChecker *op_checker) + LeakyReluOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of Square operator"); - AddOutput("Out", "Output of Square operator"); + AddInput("X", "Input of LeakyRelu operator"); + AddOutput("Out", "Output of LeakyRelu operator"); + AddAttr("alpha", "The small negative slope").SetDefault(0.02f); AddComment(R"DOC( -Square Activation Operator. +LeakyRelu Activation Operator. -$out = x^2$ +$out = \max(x, \alpha * x)$ )DOC"); } }; -class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker { +class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker { public: - SoftplusOpMaker(OpProto *proto, OpAttrChecker *op_checker) + SoftShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of Softplus operator"); - AddOutput("Out", "Output of Softplus operator"); + AddInput("X", "Input of Softshrink operator"); + AddOutput("Out", "Output of Softshrink operator"); + AddAttr("lambda", "non-negative offset").SetDefault(0.5f); AddComment(R"DOC( -Softplus Activation Operator. +Softshrink Activation Operator. -$out = \ln(1 + e^{x})$ +$$ +out = \begin{cases} + x - \lambda, \text{if } x > \lambda \\ + x + \lambda, \text{if } x < -\lambda \\ + 0, \text{otherwise} + \end{cases} +$$ )DOC"); } }; -class SoftsignOpMaker : public framework::OpProtoAndCheckerMaker { +class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker { public: - SoftsignOpMaker(OpProto *proto, OpAttrChecker *op_checker) + HardShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of Softsign operator"); - AddOutput("Out", "Output of Softsign operator"); + AddInput("X", "Input of HardShrink operator"); + AddOutput("Out", "Output of HardShrink operator"); + AddAttr("threshold", "The value of threshold for HardShrink") + .SetDefault(0.5f); AddComment(R"DOC( -Softsign Activation Operator. +HardShrink Activation Operator. -$$out = \frac{x}{1 + |x|}$$ +$$ +out = \begin{cases} + x, \text{if } x > \lambda \\ + x, \text{if } x < -\lambda \\ + 0, \text{otherwise} + \end{cases} +$$ )DOC"); } @@ -553,131 +434,86 @@ $$out = \frac{x}{1 + e^{- \beta x}}$$ } }; +REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc); +REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc); +REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc); +REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc); +REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc); +REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc); +REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc); +REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc); +REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc); +REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc); +REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc); +REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc); +REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc); +REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc); +REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc); +REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc); +REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc); +REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc); + +REGISTER_ACTIVATION_OP_GRAD_MAKER(Sigmoid, sigmoid); +REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu, relu); +REGISTER_ACTIVATION_OP_GRAD_MAKER(Exp, exp); +REGISTER_ACTIVATION_OP_GRAD_MAKER(Tanh, tanh); +REGISTER_ACTIVATION_OP_GRAD_MAKER(Ceil, ceil); +REGISTER_ACTIVATION_OP_GRAD_MAKER(Floor, floor); +REGISTER_ACTIVATION_OP_GRAD_MAKER(Sqrt, sqrt); +REGISTER_ACTIVATION_OP_GRAD_MAKER(SoftRelu, soft_relu); +REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu6, relu6); +REGISTER_ACTIVATION_OP_GRAD_MAKER(Reciprocal, reciprocal); +REGISTER_ACTIVATION_OP_GRAD_MAKER(HardSigmoid, hard_sigmoid); } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(sigmoid_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(logsigmoid, ops::ActivationOp, ops::LogSigmoidOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(logsigmoid_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(exp, ops::ActivationOp, ops::ExpOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(exp_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(relu, ops::ActivationWithMKLDNNOp, ops::ReluOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(relu_grad, ops::ActivationWithMKLDNNOpGrad); - -REGISTER_OPERATOR(tanh, ops::ActivationWithMKLDNNOp, ops::TanhOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(tanh_grad, ops::ActivationWithMKLDNNOpGrad); - -REGISTER_OPERATOR(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(tanh_shrink_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(softshrink, ops::ActivationOp, ops::SoftShrinkOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(softshrink_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(sqrt, ops::ActivationWithMKLDNNOp, ops::SqrtOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(sqrt_grad, ops::ActivationWithMKLDNNOpGrad); - -REGISTER_OPERATOR(abs, ops::ActivationWithMKLDNNOp, ops::AbsOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(abs_grad, ops::ActivationWithMKLDNNOpGrad); - -REGISTER_OPERATOR(ceil, ops::ActivationOp, ops::CeilOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(ceil_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(floor, ops::ActivationOp, ops::FloorOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(floor_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(cos, ops::ActivationOp, ops::CosOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(cos_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(sin, ops::ActivationOp, ops::SinOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(sin_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(round, ops::ActivationOp, ops::RoundOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(round_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(reciprocal, ops::ActivationOp, ops::ReciprocalOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(reciprocal_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(log, ops::ActivationOp, ops::LogOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(log_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(square, ops::ActivationOp, ops::SquareOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(softplus, ops::ActivationOp, ops::SoftplusOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(softplus_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(softsign, ops::ActivationOp, ops::SoftsignOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(softsign_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(brelu, ops::ActivationOp, ops::BReluOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(brelu_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(soft_relu_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(elu, ops::ActivationOp, ops::ELUOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(elu_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(relu6, ops::ActivationOp, ops::Relu6OpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(relu6_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(pow, ops::ActivationOp, ops::PowOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(pow_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(stanh, ops::ActivationOp, ops::STanhOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(stanh_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(hard_shrink, ops::ActivationOp, ops::HardShrinkOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(hard_shrink_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(thresholded_relu, ops::ActivationOp, - ops::ThresholdedReluOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(thresholded_relu_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(hard_sigmoid, ops::ActivationOp, ops::HardSigmoidOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(hard_sigmoid_grad, ops::ActivationOpGrad); - -REGISTER_OPERATOR(swish, ops::ActivationOp, ops::SwishOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(swish_grad, ops::ActivationOpGrad); +#define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \ + __macro(Sigmoid, sigmoid); \ + __macro(Relu, relu); \ + __macro(Exp, exp); \ + __macro(Tanh, tanh); \ + __macro(Ceil, ceil); \ + __macro(Floor, floor); \ + __macro(Sqrt, sqrt); \ + __macro(SoftRelu, soft_relu); \ + __macro(Relu6, relu6); \ + __macro(Reciprocal, reciprocal); \ + __macro(HardSigmoid, hard_sigmoid); + +#define FOR_EACH_OP_FUNCTOR(__macro) \ + __macro(LogSigmoid, logsigmoid); \ + __macro(SoftShrink, softshrink); \ + __macro(Abs, abs); \ + __macro(Cos, cos); \ + __macro(Sin, sin); \ + __macro(Round, round); \ + __macro(Log, log); \ + __macro(Square, square); \ + __macro(BRelu, brelu); \ + __macro(Pow, pow); \ + __macro(STanh, stanh); \ + __macro(Softplus, softplus); \ + __macro(Softsign, softsign); \ + __macro(LeakyRelu, leaky_relu); \ + __macro(TanhShrink, tanh_shrink); \ + __macro(ELU, elu); \ + __macro(HardShrink, hard_shrink); \ + __macro(Swish, swish); \ + __macro(ThresholdedRelu, thresholded_relu); + +#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \ + REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \ + ::paddle::operators::OP_NAME##OpMaker, \ + ::paddle::operators::OP_NAME##GradMaker); \ + REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad) + +#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \ + REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \ + ::paddle::operators::OP_NAME##OpMaker, \ + ::paddle::framework::DefaultGradOpDescMaker); \ + REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad) #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \ REGISTER_OP_CPU_KERNEL( \ @@ -692,4 +528,6 @@ REGISTER_OPERATOR(swish_grad, ops::ActivationOpGrad); ops::ActivationGradKernel>); +FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP); +FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP); FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL); diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 4f745553c14fc1391bc65d4f7e4f9bd3b5a881c2..27487b396ccf63d962defa6b270063ccb409164e 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -9,7 +9,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU #include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/platform/float16.h" diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 43856780bf9357281ac4af2968950da15426e5c8..912415192659dc004f54a76e9cd1a20581d512a6 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -10,6 +10,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include +#include +#include #include #include @@ -25,6 +28,16 @@ limitations under the License. */ namespace paddle { namespace operators { +/* Use ugly global variable, for the using in python layer side + Please refer to the layer_helper.py and get the details. + */ +static std::unordered_set InplaceOpSet = { + "sigmoid", "exp", "relu", "tanh", "sqrt", "ceil", + "floor", "reciprocal", "relu6", "soft_relu", "hard_sigmoid", +}; + +static bool IsInplace(std::string op) { return InplaceOpSet.count(op); } + template class ActivationKernel : public framework::OpKernel { @@ -60,7 +73,6 @@ class ActivationGradKernel public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { - auto* X = context.Input("X"); auto* Out = context.Input("Out"); auto* dOut = context.Input(framework::GradVarName("Out")); @@ -68,7 +80,6 @@ class ActivationGradKernel dX->mutable_data(context.GetPlace()); auto dout = framework::EigenVector::Flatten(*dOut); - auto x = framework::EigenVector::Flatten(*X); auto out = framework::EigenVector::Flatten(*Out); auto dx = framework::EigenVector::Flatten(*dX); auto* place = @@ -78,7 +89,16 @@ class ActivationGradKernel for (auto& attr : attrs) { *attr.second = context.Attr(attr.first); } - functor(*place, x, out, dout, dx); + bool inplace = functor.Inplace(); + if (!inplace) { + auto* X = context.Input("X"); + auto x = framework::EigenVector::Flatten(*X); + functor(*place, x, out, dout, dx); + } else { + VLOG(10) << " Inplace activation "; + auto x = framework::EigenVector::Flatten(*dX); + functor(*place, x, out, dout, dx); + } } }; @@ -89,6 +109,14 @@ struct BaseActivationFunctor { using AttrPair = std::vector>; AttrPair GetAttrs() { return AttrPair(); } + + /* NOTE(*): Output reuse X memory if X is not dependented by its Gradient. + For example, sigmoid op's gradient didn't involve x, so its output can + reuse + input memory. But abs op's gradient use x, it can not be inplaced. + gradient did use x. + */ + bool Inplace() const { return false; } }; // sigmoid(x) = 1 / (1 + exp(-x)) @@ -102,6 +130,7 @@ struct SigmoidFunctor : public BaseActivationFunctor { template struct SigmoidGradFunctor : public BaseActivationFunctor { + bool Inplace() const { return IsInplace("sigmoid"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { @@ -156,6 +185,7 @@ struct ExpFunctor : public BaseActivationFunctor { template struct ExpGradFunctor : public BaseActivationFunctor { + bool Inplace() const { return IsInplace("exp"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { @@ -174,10 +204,11 @@ struct ReluFunctor : public BaseActivationFunctor { template struct ReluGradFunctor : public BaseActivationFunctor { + bool Inplace() const { return IsInplace("relu"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = dout * (x > static_cast(0)).template cast(); + dx.device(d) = dout * (out > static_cast(0)).template cast(); } }; @@ -192,6 +223,7 @@ struct TanhFunctor : public BaseActivationFunctor { template struct TanhGradFunctor : public BaseActivationFunctor { + bool Inplace() const { return IsInplace("tanh"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { @@ -297,6 +329,7 @@ struct SqrtFunctor : public BaseActivationFunctor { template struct SqrtGradFunctor : public BaseActivationFunctor { + bool Inplace() const { return IsInplace("sqrt"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { @@ -316,10 +349,11 @@ struct CeilFunctor : public BaseActivationFunctor { template struct ZeroGradFunctor : public BaseActivationFunctor { + bool Inplace() const { return IsInplace("ceil"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = static_cast(0) / x; + dx.device(d) = static_cast(0) / out; } }; @@ -432,6 +466,7 @@ struct ReciprocalFunctor : public BaseActivationFunctor { template struct ReciprocalGradFunctor : public BaseActivationFunctor { + bool Inplace() const { return IsInplace("reciprocal"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { @@ -531,12 +566,14 @@ struct Relu6GradFunctor : public BaseActivationFunctor { typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } + bool Inplace() const { return IsInplace("relu6"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = dout * - ((x > static_cast(0)) * (x < static_cast(threshold))) - .template cast(); + dx.device(d) = + dout * + ((out > static_cast(0)) * (out < static_cast(threshold))) + .template cast(); } }; @@ -611,11 +648,12 @@ struct SoftReluGradFunctor : public BaseActivationFunctor { typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } + bool Inplace() const { return IsInplace("soft_relu"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto tmp = static_cast(threshold); - auto temp = ((x > -tmp) * (x < tmp)).template cast().eval(); + auto temp = ((out > -tmp) * (out < tmp)).template cast().eval(); dx.device(d) = dout * (static_cast(1) - (-out).exp()) * temp; } }; @@ -791,7 +829,7 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor { typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"slope", &slope}, {"offset", &offset}}; } - + bool Inplace() { return IsInplace("hard_sigmoid"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu index 490dce19b69c3a9bfbb5639926f44ad2a94ab360..1dd66e0280c46c0624ff70e822cb6fa6f06b7aa9 100644 --- a/paddle/fluid/operators/dropout_op.cu +++ b/paddle/fluid/operators/dropout_op.cu @@ -24,12 +24,34 @@ namespace paddle { namespace operators { template -__global__ void RandomGenerator(const size_t n, const T* src, - const T* cpu_mask_data, T* mask_data, T* dst) { +__global__ void RandomGenerator(const size_t n, const int seed, + const float dropout_prob, const T* src, + T* mask_data, T* dst) { + thrust::minstd_rand rng; + rng.seed(seed); + thrust::uniform_real_distribution dist(0, 1); + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int step_size = 0; + + T mask; + T dest; for (; idx < n; idx += blockDim.x * gridDim.x) { - mask_data[idx] = cpu_mask_data[idx]; - dst[idx] = mask_data[idx] * src[idx]; + T s = src[idx]; + if (step_size == 0) { + rng.discard(idx); + step_size = blockDim.x * gridDim.x; + } else { + rng.discard(step_size); + } + if (dist(rng) < dropout_prob) { + mask = static_cast(0); + } else { + mask = static_cast(1); + } + dest = s * mask; + mask_data[idx] = mask; + dst[idx] = dest; } } @@ -56,27 +78,15 @@ class GPUDropoutKernel : public framework::OpKernel { std::random_device rnd; int seed = context.Attr("fix_seed") ? context.Attr("seed") : rnd(); - std::minstd_rand engine; - engine.seed(seed); - std::uniform_real_distribution dist(0, 1); - framework::Vector cpu_mask(size); - for (size_t i = 0; i < size; ++i) { - if (dist(engine) < dropout_prob) { - cpu_mask[i] = static_cast(0); - } else { - cpu_mask[i] = static_cast(1); - } - } int threads = 512; int grid = (x->numel() + threads - 1) / threads; RandomGenerator< T><<>>( - size, x_data, cpu_mask.CUDAData(context.GetPlace()), mask_data, - y_data); + size, seed, dropout_prob, x_data, mask_data, y_data); } else { - auto X = EigenVector::Flatten(*x); - auto Y = EigenVector::Flatten(*y); + auto X = EigenMatrix::Reshape(*x, 1); + auto Y = EigenMatrix::Reshape(*y, 1); Y.device(place) = X * static_cast(1.0f - dropout_prob); } } @@ -89,8 +99,6 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( dropout, ops::GPUDropoutKernel, - ops::GPUDropoutKernel, ops::GPUDropoutKernel); REGISTER_OP_CUDA_KERNEL(dropout_grad, - ops::DropoutGradKernel, ops::DropoutGradKernel); diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h index 41ca242d8f071bce6ff98951a5c273ea0bdc8ab5..0628b4b826d2730a8e3fb4842e4ae550b8c00569 100644 --- a/paddle/fluid/operators/dropout_op.h +++ b/paddle/fluid/operators/dropout_op.h @@ -24,7 +24,7 @@ namespace operators { using Tensor = framework::Tensor; template -using EigenVector = framework::EigenVector; +using EigenMatrix = framework::EigenMatrix; template class CPUDropoutKernel : public framework::OpKernel { @@ -60,8 +60,8 @@ class CPUDropoutKernel : public framework::OpKernel { } } } else { - auto X = EigenVector::Flatten(*x); - auto Y = EigenVector::Flatten(*y); + auto X = EigenMatrix::Reshape(*x, 1); + auto Y = EigenMatrix::Reshape(*y, 1); auto& place = *context.template device_context().eigen_device(); Y.device(place) = X * (1.0f - dropout_prob); @@ -81,9 +81,9 @@ class DropoutGradKernel : public framework::OpKernel { auto* mask = context.Input("Mask"); grad_x->mutable_data(context.GetPlace()); - auto M = EigenVector::Flatten(*mask); - auto dX = EigenVector::Flatten(*grad_x); - auto dY = EigenVector::Flatten(*grad_y); + auto M = EigenMatrix::Reshape(*mask, 1); + auto dX = EigenMatrix::Reshape(*grad_x, 1); + auto dY = EigenMatrix::Reshape(*grad_y, 1); auto& place = *context.template device_context().eigen_device(); diff --git a/paddle/fluid/operators/dropout_op_test.cc b/paddle/fluid/operators/dropout_op_test.cc index 47ea8476748f604e7b3635b53a3b6435f930f8c9..424d273c34b7e8d70c88b591c4fe45db61465f38 100644 --- a/paddle/fluid/operators/dropout_op_test.cc +++ b/paddle/fluid/operators/dropout_op_test.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include -#include #include #include // NOLINT @@ -33,16 +32,14 @@ namespace m = paddle::operators::math; USE_OP(dropout); -static paddle::framework::DDim dims = {10, 10}; - void Compare(f::Scope* scope, const p::DeviceContext& ctx) { // init auto var = scope->Var("X"); auto tensor = var->GetMutable(); - tensor->Resize(dims); + tensor->Resize({10, 10}); std::vector init; - for (int64_t i = 0; i < f::product(dims); ++i) { + for (int64_t i = 0; i < 10 * 10; ++i) { init.push_back(1.0); } @@ -51,19 +48,18 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx) { auto place = ctx.GetPlace(); auto out_var = scope->Var("Out"); auto out_tensor = out_var->GetMutable(); - out_tensor->Resize(dims); + out_tensor->Resize({10, 10}); out_tensor->mutable_data(place); // allocate auto mask_var = scope->Var("Mask"); auto mask_tensor = mask_var->GetMutable(); - mask_tensor->Resize(dims); + mask_tensor->Resize({10, 10}); mask_tensor->mutable_data(place); // allocate // run f::AttributeMap attrs; float dropout_prob = 0.5; - attrs.insert({"is_test", false}); - attrs.insert({"fix_seed", true}); + attrs.insert({"fix_seed", 1}); attrs.insert({"seed", 3}); attrs.insert({"dropout_prob", dropout_prob}); auto dropout_op = f::OpRegistry::CreateOp( @@ -73,7 +69,6 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx) { std::vector out_vec; TensorToVector(*out_tensor, ctx, &out_vec); - ctx.Wait(); std::vector std_out = { 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, @@ -88,22 +83,22 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx) { } } +// TODO(wyi): Due to +// https://github.com/PaddlePaddle/Paddle/issues/9507, I temporarily +// disable this test to remove the prevention of the merge of +// unrelated PRs. +/* TEST(Dropout, CPUDense) { f::Scope scope; p::CPUPlace place; p::CPUDeviceContext ctx(place); - Compare(&scope, ctx); + Compare(scope, ctx); } -// TODO(wyi, dzhwinter): Due to -// https://github.com/PaddlePaddle/Paddle/issues/9507, I temporarily -// disable this test to remove the prevention of the merge of -// unrelated PRs. -/* TEST(Dropout, GPUDense) { f::Scope scope; p::CUDAPlace place; p::CUDADeviceContext ctx(place); - Compare(&scope, ctx); + Compare(scope, ctx); } */ diff --git a/paddle/fluid/operators/mkldnn_activation_op.h b/paddle/fluid/operators/mkldnn_activation_op.h index 083d03ebe610521c5a4beb7b977a8179700bcf40..f26a165b5a59f01f864d62bbf798f4cbffa65371 100644 --- a/paddle/fluid/operators/mkldnn_activation_op.h +++ b/paddle/fluid/operators/mkldnn_activation_op.h @@ -60,7 +60,7 @@ class MKLDNNActivationGradKernel } }; -namespace { +namespace { // NOLINT framework::OpKernelType GetKernelType( const framework::ExecutionContext& ctx, const framework::OperatorWithKernel& oper) { diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 7de7f84a3dc76195d0098d7bb9baf0461aff3575..6471eb3ab7bf05365c0bb2bf68bb74ef9044c527 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -127,6 +127,7 @@ void BindProgramDesc(pybind11::module *m) { .def("block", &pd::ProgramDesc::MutableBlock, pybind11::return_value_policy::reference) .def("num_blocks", &pd::ProgramDesc::Size) + .def("flush", &pd::ProgramDesc::Flush) .def("get_feed_target_names", &pd::ProgramDesc::GetFeedTargetNames) .def("get_fetch_target_names", &pd::ProgramDesc::GetFetchTargetNames) .def("serialize_to_string", SerializeMessage) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 64d92cac7eca1086cd3cdcd48c668194d202e991..1f21e7abe76b2a32d6c18e5c26c4f25b65daef5b 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -33,6 +33,7 @@ limitations under the License. */ #include "paddle/fluid/framework/prune.h" #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler.h" @@ -461,6 +462,9 @@ All parameter, weight, gradient are variables in Paddle. self.back().set_lod(t.lod()); }); + m.def("IsInplace", + [](std::string op) -> bool { return operators::IsInplace(op); }); + m.def("op_support_gpu", OpSupportGPU); #ifdef PADDLE_WITH_CUDA m.def("get_cuda_device_count", platform::GetCUDADeviceCount); diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 86ef3e4df153dba905fb5f09e6c2030724270de8..94628270228b9e7fd32405bdcb5e11c163ba4791 100755 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -155,7 +155,7 @@ EOF function gen_dockerfile() { # Set BASE_IMAGE according to env variables if [[ ${WITH_GPU} == "ON" ]]; then - BASE_IMAGE="nvidia/cuda:8.0-cudnn5-runtime-ubuntu16.04" + BASE_IMAGE="nvidia/cuda:8.0-cudnn7-runtime-ubuntu16.04" else BASE_IMAGE="ubuntu:16.04" fi @@ -164,7 +164,7 @@ function gen_dockerfile() { DOCKERFILE_CUDNN_DSO="" if [[ ${WITH_GPU:-OFF} == 'ON' ]]; then DOCKERFILE_GPU_ENV="ENV LD_LIBRARY_PATH /usr/lib/x86_64-linux-gnu:\${LD_LIBRARY_PATH}" - DOCKERFILE_CUDNN_DSO="RUN ln -s /usr/lib/x86_64-linux-gnu/libcudnn.so.5 /usr/lib/x86_64-linux-gnu/libcudnn.so" + DOCKERFILE_CUDNN_DSO="RUN ln -s /usr/lib/x86_64-linux-gnu/libcudnn.so.7 /usr/lib/x86_64-linux-gnu/libcudnn.so" fi cat <