diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 8da0aaaafeb151e8f1900bc66f06e771c857fc00..e73d31562a93af4e9bdb3d5806e9182d0eea8167 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -291,6 +291,7 @@ op_library(unsqueeze_op DEPS reshape_op) op_library(squeeze_op DEPS reshape_op) op_library(extract_rows_op DEPS memory) op_library(flatten_op DEPS reshape_op) +op_library(unstack_op DEPS stack_op) if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) diff --git a/paddle/fluid/operators/unstack_op.cc b/paddle/fluid/operators/unstack_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4ff3249cc333231a0624cd5aab9603a6a75f4480 --- /dev/null +++ b/paddle/fluid/operators/unstack_op.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/unstack_op.h" + +namespace plat = paddle::platform; +namespace ops = paddle::operators; + +USE_OP(stack); + +REGISTER_OPERATOR(unstack, ops::UnStackOp, ops::UnStackOpMaker, + ops::UnStackOpInferShape, ops::UnStackGradOpDescMaker); + +REGISTER_OPERATOR(unstack_grad, ops::UnStackGradOp, + ops::UnStackOpGradInferShape); diff --git a/paddle/fluid/operators/unstack_op.h b/paddle/fluid/operators/unstack_op.h new file mode 100644 index 0000000000000000000000000000000000000000..348a1038804ccb2551e5f729cc1a38bcef1511f5 --- /dev/null +++ b/paddle/fluid/operators/unstack_op.h @@ -0,0 +1,135 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class UnStackOpInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist."); + + int axis = ctx->Attrs().Get("axis"); + int num = ctx->Attrs().Get("num"); + auto x_dim = ctx->GetInputDim("X"); + int rank = x_dim.size(); + PADDLE_ENFORCE(axis >= -rank && axis < rank, + "Attr(axis) must be inside [-rank, rank), where rank = %d", + rank); + if (axis < 0) axis += rank; + + PADDLE_ENFORCE_EQ(ctx->Outputs("Y").size(), static_cast(num), + "Number of Outputs(Y) is wrong"); + if (x_dim[axis] > 0) { + PADDLE_ENFORCE_EQ(num, x_dim[axis], "Number of Outputs(Y) is wrong"); + } + auto vec = framework::vectorize2int(x_dim); + vec.erase(vec.begin() + axis); + ctx->SetOutputsDim("Y", std::vector( // NOLINT + x_dim[axis], framework::make_ddim(vec))); + } +}; + +class UnStackOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input of unstack op."); + AddOutput("Y", "The output of unstack op.").AsDuplicable(); + AddAttr("axis", "The axis along which Input(X) should be unstacked.") + .SetDefault(0); + AddAttr("num", "The number of outputs(Y).").GreaterThan(0); + AddComment(R"DOC( + UnStack Operator. + + UnStack Input(X) into several tensors along Attr(axis). + )DOC"); + } +}; + +class UnStackOp : public framework::OperatorBase { + public: + using OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto stack_grad_op = framework::OpRegistry::CreateOp( + "stack_grad", {{framework::GradVarName("Y"), {Input("X")}}}, + {{framework::GradVarName("X"), Outputs("Y")}}, Attrs()); + stack_grad_op->Run(scope, place); + } +}; + +class UnStackOpGradInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_GT(ctx->Inputs(framework::GradVarName("Y")).size(), 0, + "Number of Inputs(Y@Grad) must be larger than 0"); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@Grad) must exist."); + + auto input_dims = ctx->GetInputsDim(framework::GradVarName("Y")); + for (size_t i = 1; i < input_dims.size(); ++i) { + PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0], + "Dims of all Inputs(Y@Grad) must be the same"); + } + + int axis = ctx->Attrs().Get("axis"); + int rank = input_dims[0].size(); + PADDLE_ENFORCE( + axis >= -(rank + 1) && axis < rank + 1, + "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank); + if (axis < 0) axis += (rank + 1); + + auto vec = framework::vectorize2int(input_dims[0]); + vec.insert(vec.begin() + axis, input_dims.size()); + ctx->SetOutputDim(framework::GradVarName("X"), framework::make_ddim(vec)); + } +}; + +class UnStackGradOpDescMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("unstack_grad"); + op->SetInput(framework::GradVarName("Y"), OutputGrad("Y")); + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetAttrMap(Attrs()); + return op; + } +}; + +class UnStackGradOp : public framework::OperatorBase { + public: + using OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto stack_op = framework::OpRegistry::CreateOp( + "stack", {{"X", Inputs(framework::GradVarName("Y"))}}, + {{"Y", {Output(framework::GradVarName("X"))}}}, Attrs()); + stack_op->Run(scope, place); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 75d3856d0dda8b5bf6a4fa11954a611bf140c9bc..e25efebe6c3555958f4f75e2b87b7dc45d4a4177 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -1,3 +1,4 @@ +if (NOT WIN32) proto_library(profiler_proto SRCS profiler.proto DEPS framework_proto) py_proto_compile(profiler_py_proto SRCS profiler.proto) @@ -10,6 +11,7 @@ add_custom_command(TARGET profiler_py_proto POST_BUILD COMMAND cp *.py ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/profiler COMMENT "Copy generated python proto into directory paddle/fluid/proto/profiler." WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) +endif(NOT WIN32) if(WITH_GPU) nv_library(enforce SRCS enforce.cc) @@ -58,9 +60,12 @@ cc_test(init_test SRCS init_test.cc DEPS device_context) nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda) nv_test(transform_test SRCS transform_test.cu DEPS memory place device_context) + +if (NOT WIN32) cc_library(device_tracer SRCS device_tracer.cc DEPS boost profiler_proto framework_proto ${GPU_CTX_DEPS}) cc_library(profiler SRCS profiler.cc DEPS device_context device_tracer) cc_test(profiler_test SRCS profiler_test.cc DEPS profiler) +endif(NOT WIN32) nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor) cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor) diff --git a/paddle/fluid/platform/cpu_info.cc b/paddle/fluid/platform/cpu_info.cc index fcd658d67cf4551dbdb9696ef49b5ab3cc58bf95..2880c09263f10e9c624e11b77188171f48d9db28 100644 --- a/paddle/fluid/platform/cpu_info.cc +++ b/paddle/fluid/platform/cpu_info.cc @@ -22,9 +22,13 @@ limitations under the License. */ #ifdef __APPLE__ #include #include + +#elif defined(_WIN32) +#define NOMINMAX // msvc max/min macro conflict with std::min/max +#include #else #include -#endif +#endif // _WIN32 #include #include "gflags/gflags.h" @@ -32,16 +36,20 @@ limitations under the License. */ DEFINE_double(fraction_of_cpu_memory_to_use, 1, "Default use 100% of CPU memory for PaddlePaddle," "reserve the rest for page tables, etc"); - +#if !defined(_WIN32) DEFINE_uint64(initial_cpu_memory_in_mb, #ifdef PADDLE_WITH_MKLDNN /* Aligned with mozga-intel, MKLDNN need at least 5000 MB * to obtain the best performance*/ - 5000, + 5000ul, #else - 500, + 500ul, #endif "Initial CPU memory for PaddlePaddle, in MD unit."); +#else +DEFINE_uint64(initial_cpu_memory_in_mb, 500ul, + "Initial CPU memory for PaddlePaddle, in MD unit."); +#endif // !defined(_WIN32) DEFINE_double( fraction_of_cuda_pinned_memory_to_use, 0.5, @@ -60,6 +68,11 @@ inline size_t CpuTotalPhysicalMemory() { size_t len = sizeof(size); if (sysctl(mib, 2, &size, &len, NULL, 0) == 0) return (size_t)size; return 0L; +#elif defined(_WIN32) + MEMORYSTATUSEX sMeminfo; + sMeminfo.dwLength = sizeof(sMeminfo); + GlobalMemoryStatusEx(&sMeminfo); + return sMeminfo.ullTotalPhys; #else int64_t pages = sysconf(_SC_PHYS_PAGES); int64_t page_size = sysconf(_SC_PAGE_SIZE); diff --git a/paddle/fluid/platform/device_tracer.h b/paddle/fluid/platform/device_tracer.h index 322996fb4f54d34ebbb034a6e1de420e9c532545..f59fc40b71699a790978e22fd7e26da8d4d94c5f 100644 --- a/paddle/fluid/platform/device_tracer.h +++ b/paddle/fluid/platform/device_tracer.h @@ -13,7 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#if !defined(_WIN32) #include +#else +#include +#endif // !_WIN32 + #include #include // NOLINT #include @@ -27,12 +32,15 @@ namespace platform { /////////////////////// // WARN: Under Development. Don't depend on it yet. ////////////////////// - +#if !defined(_WIN32) inline uint64_t PosixInNsec() { struct timeval tv; gettimeofday(&tv, nullptr); return 1000 * (static_cast(tv.tv_sec) * 1000000 + tv.tv_usec); } +#else +inline uint64_t PosixInNsec() { return static_cast(0); } +#endif // !_WIN32 // DeviceTracer performs the following tasks: // 1. Register cuda callbacks for various events: kernel, memcpy, etc. diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index 07159d4a12ef4b628f7705ed206d3334be46dfc8..5939c500c946c44579d1de645ac9700c7701a4e9 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -16,7 +16,9 @@ if (CUPTI_FOUND) list(APPEND CUDA_SRCS cupti.cc) endif(CUPTI_FOUND) nv_library(dynload_cuda SRCS ${CUDA_SRCS} DEPS dynamic_loader) +if (NOT WIN32) cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc) +endif(NOT WIN32) if (WITH_MKLML) cc_library(dynload_mklml SRCS mklml.cc DEPS dynamic_loader mklml) endif() diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index 93bf7c13516ffa4baca6a30f1daf946939726d85..4fbfa6354ab45fed4839227a2a4be8fe147e5fd9 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/platform/dynload/dynamic_loader.h" -#include - #include #include // NOLINT #include @@ -23,6 +21,7 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/fluid/platform/dynload/cupti_lib_path.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/port.h" DEFINE_string(cudnn_dir, "", "Specify path for loading libcudnn.so. For instance, " diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index a76ba75f9eeb8c3f42fbf7254f629b0960a8f2d8..61a653d9313daff96d39c08e80f17d7e33acceb1 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -18,6 +18,11 @@ limitations under the License. */ #include // for __cxa_demangle #endif // __GNUC__ +#if defined(_WIN32) +#define NOMINMAX // msvc max/min macro conflict with std::min/max +#define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h +#endif + #ifdef PADDLE_WITH_CUDA #include #include @@ -117,7 +122,12 @@ struct EOFException : public std::exception { // always forces branch prediction of true. // This generates faster binary code. __builtin_expect is since C++11. // For more details, please check https://stackoverflow.com/a/43870188/724872. +#if !defined(_WIN32) #define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) +#else +// there is no equivalent intrinsics in msvc. +#define UNLIKELY(condition) (condition == 0) +#endif template inline typename std::enable_if::type throw_on_error( @@ -230,6 +240,7 @@ inline void throw_on_error(T e) { throw_on_error(e, ""); } +#if !defined(_WIN32) #define PADDLE_THROW(...) \ do { \ throw ::paddle::platform::EnforceNotMet( \ @@ -248,15 +259,28 @@ inline void throw_on_error(T e) { __FILE__, __LINE__); \ } \ } while (false) -#else -#define PADDLE_ENFORCE(...) ::paddle::platform::throw_on_error(__VA_ARGS__); -#endif #define PADDLE_THROW_EOF() \ do { \ throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \ __LINE__); \ } while (false) + +#else +#define PADDLE_ENFORCE(...) ::paddle::platform::throw_on_error(__VA_ARGS__) +#endif // REPLACE_ENFORCE_GLOG + +#else // !_WIN32 +// disable enforce, caused by the varardic macro exception error +#define PADDLE_THROW(x) \ + do { \ + throw std::make_exception_ptr( \ + std::runtime_error("Windows disable the enforce.")); \ + } while (false) + +#define PADDLE_ENFORCE(x, ...) x +#endif // !_WIN32 + /* * Some enforce helpers here, usage: * int a = 1; diff --git a/paddle/fluid/platform/profiler.h b/paddle/fluid/platform/profiler.h index c99d9c807d1bfb45d1ce0725b84b9fff09049511..38630686f7cf3c669373f941d989adf11ba6cfe6 100644 --- a/paddle/fluid/platform/profiler.h +++ b/paddle/fluid/platform/profiler.h @@ -69,6 +69,7 @@ void PushEvent(const std::string& name, const DeviceContext* dev_ctx); void PopEvent(const std::string& name, const DeviceContext* dev_ctx); +#if !defined(_WIN32) struct RecordEvent { RecordEvent(const std::string& name, const DeviceContext* dev_ctx); @@ -94,6 +95,15 @@ struct RecordBlock { std::string name_; uint64_t start_ns_; }; +#else +// windows do not support profiler temporarily. +struct RecordEvent { + RecordEvent(const std::string& name, const DeviceContext* dev_ctx) {} +}; +struct RecordBlock { + explicit RecordBlock(int block_id) {} +}; +#endif // Return the event list of all threads. Assumed the returned value calls // event_lists, event_lists[i][j] represents the j-th Event of i-th thread. diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 66b776c08e4158e8ce7df6c66f052a6925c043e8..44416381c70f4bbdc6fe91a475da27890af900a5 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -105,6 +105,7 @@ __all__ = [ 'flatten', 'sequence_mask', 'stack', + 'unstack', ] @@ -5601,3 +5602,44 @@ def stack(x, axis=0): type='stack', inputs={'X': x}, outputs={'Y': out}, attrs={'axis': axis}) return out + + +def unstack(x, axis=0, num=None): + """ + **UnStack Layer** + + This layer unstacks input :code:`x` into several tensors along axis. + + If :code:`axis` < 0, it would be replaced with :code:`axis+rank(x)`. + If :code:`num` is None, it would be inferred from :code:`x.shape[axis]`, + and if :code:`x.shape[axis]` <= 0 or is unknown, :code:`ValueError` is + raised. + + Args: + x (Variable): Input variable. + axis (int): The axis along which the input is unstacked. + num (int|None): The number of output variables. + + Returns: + list(Variable): The unstacked variables. + + """ + + helper = LayerHelper('unstack', **locals()) + if num is None: + if axis is None or x.shape[axis] <= 0: + raise ValueError('unknown unstack number') + else: + num = x.shape[axis] + + outs = [] + for _ in num: + outs.append(helper.create_tmp_variable(x.dtype)) + + helper.append_op( + type='unstack', + inputs={'X': [x]}, + outputs={'Y': outs}, + attrs={'axis': axis, + 'num': num}) + return outs diff --git a/python/paddle/fluid/tests/unittests/test_unstack_op.py b/python/paddle/fluid/tests/unittests/test_unstack_op.py new file mode 100644 index 0000000000000000000000000000000000000000..7cbac8928ec40dc3e1c0e91e7779ec9ec978d884 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_unstack_op.py @@ -0,0 +1,81 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from op_test import OpTest +import numpy as np +import unittest + + +class TestUnStackOpBase(OpTest): + def initDefaultParameters(self): + self.input_dim = (5, 6, 7) + self.axis = 0 + self.dtype = 'float32' + + def initParameters(self): + pass + + def get_y_names(self): + y_names = [] + for i in range(self.input_dim[self.axis]): + y_names.append('y{}'.format(i)) + return y_names + + def setUp(self): + self.initDefaultParameters() + self.initParameters() + self.op_type = 'unstack' + self.x = np.random.random(size=self.input_dim).astype(self.dtype) + + outs = np.split(self.x, self.input_dim[self.axis], self.axis) + new_shape = list(self.input_dim) + del new_shape[self.axis] + y_names = self.get_y_names() + tmp = [] + for i in range(self.input_dim[self.axis]): + tmp.append((y_names[i], np.reshape(outs[i], new_shape))) + + self.inputs = {'X': self.x} + self.outputs = {'Y': tmp} + self.attrs = {'axis': self.axis, 'num': self.input_dim[self.axis]} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad('X', self.get_y_names()) + + +class TestStackOp3(TestUnStackOpBase): + def initParameters(self): + self.axis = -1 + + +class TestStackOp4(TestUnStackOpBase): + def initParameters(self): + self.axis = -3 + + +class TestStackOp5(TestUnStackOpBase): + def initParameters(self): + self.axis = 1 + + +class TestStackOp6(TestUnStackOpBase): + def initParameters(self): + self.axis = 2 + + +if __name__ == '__main__': + unittest.main()