提交 0153c21d 编写于 作者: D dzhwinter 提交者: sneaxiy

add unstack_op

上级 d292ad85
......@@ -291,6 +291,7 @@ op_library(unsqueeze_op DEPS reshape_op)
op_library(squeeze_op DEPS reshape_op)
op_library(extract_rows_op DEPS memory)
op_library(flatten_op DEPS reshape_op)
op_library(unstack_op DEPS stack_op)
if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv im2col)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/unstack_op.h"
namespace plat = paddle::platform;
namespace ops = paddle::operators;
USE_OP(stack);
REGISTER_OPERATOR(unstack, ops::UnStackOp, ops::UnStackOpMaker,
ops::UnStackOpInferShape, ops::UnStackGradOpDescMaker);
REGISTER_OPERATOR(unstack_grad, ops::UnStackGradOp,
ops::UnStackOpGradInferShape);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class UnStackOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist.");
int axis = ctx->Attrs().Get<int>("axis");
int num = ctx->Attrs().Get<int>("num");
auto x_dim = ctx->GetInputDim("X");
int rank = x_dim.size();
PADDLE_ENFORCE(axis >= -rank && axis < rank,
"Attr(axis) must be inside [-rank, rank), where rank = %d",
rank);
if (axis < 0) axis += rank;
PADDLE_ENFORCE_EQ(ctx->Outputs("Y").size(), static_cast<size_t>(num),
"Number of Outputs(Y) is wrong");
if (x_dim[axis] > 0) {
PADDLE_ENFORCE_EQ(num, x_dim[axis], "Number of Outputs(Y) is wrong");
}
auto vec = framework::vectorize2int(x_dim);
vec.erase(vec.begin() + axis);
ctx->SetOutputsDim("Y", std::vector<framework::DDim>( // NOLINT
x_dim[axis], framework::make_ddim(vec)));
}
};
class UnStackOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input of unstack op.");
AddOutput("Y", "The output of unstack op.").AsDuplicable();
AddAttr<int>("axis", "The axis along which Input(X) should be unstacked.")
.SetDefault(0);
AddAttr<int>("num", "The number of outputs(Y).").GreaterThan(0);
AddComment(R"DOC(
UnStack Operator.
UnStack Input(X) into several tensors along Attr(axis).
)DOC");
}
};
class UnStackOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto stack_grad_op = framework::OpRegistry::CreateOp(
"stack_grad", {{framework::GradVarName("Y"), {Input("X")}}},
{{framework::GradVarName("X"), Outputs("Y")}}, Attrs());
stack_grad_op->Run(scope, place);
}
};
class UnStackOpGradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_GT(ctx->Inputs(framework::GradVarName("Y")).size(), 0,
"Number of Inputs(Y@Grad) must be larger than 0");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@Grad) must exist.");
auto input_dims = ctx->GetInputsDim(framework::GradVarName("Y"));
for (size_t i = 1; i < input_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
"Dims of all Inputs(Y@Grad) must be the same");
}
int axis = ctx->Attrs().Get<int>("axis");
int rank = input_dims[0].size();
PADDLE_ENFORCE(
axis >= -(rank + 1) && axis < rank + 1,
"Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
if (axis < 0) axis += (rank + 1);
auto vec = framework::vectorize2int(input_dims[0]);
vec.insert(vec.begin() + axis, input_dims.size());
ctx->SetOutputDim(framework::GradVarName("X"), framework::make_ddim(vec));
}
};
class UnStackGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("unstack_grad");
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
class UnStackGradOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto stack_op = framework::OpRegistry::CreateOp(
"stack", {{"X", Inputs(framework::GradVarName("Y"))}},
{{"Y", {Output(framework::GradVarName("X"))}}}, Attrs());
stack_op->Run(scope, place);
}
};
} // namespace operators
} // namespace paddle
if (NOT WIN32)
proto_library(profiler_proto SRCS profiler.proto DEPS framework_proto)
py_proto_compile(profiler_py_proto SRCS profiler.proto)
......@@ -10,6 +11,7 @@ add_custom_command(TARGET profiler_py_proto POST_BUILD
COMMAND cp *.py ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/profiler
COMMENT "Copy generated python proto into directory paddle/fluid/proto/profiler."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif(NOT WIN32)
if(WITH_GPU)
nv_library(enforce SRCS enforce.cc)
......@@ -58,9 +60,12 @@ cc_test(init_test SRCS init_test.cc DEPS device_context)
nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda)
nv_test(transform_test SRCS transform_test.cu DEPS memory place device_context)
if (NOT WIN32)
cc_library(device_tracer SRCS device_tracer.cc DEPS boost profiler_proto framework_proto ${GPU_CTX_DEPS})
cc_library(profiler SRCS profiler.cc DEPS device_context device_tracer)
cc_test(profiler_test SRCS profiler_test.cc DEPS profiler)
endif(NOT WIN32)
nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor)
cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor)
......
......@@ -22,9 +22,13 @@ limitations under the License. */
#ifdef __APPLE__
#include <sys/sysctl.h>
#include <sys/types.h>
#elif defined(_WIN32)
#define NOMINMAX // msvc max/min macro conflict with std::min/max
#include <windows.h>
#else
#include <unistd.h>
#endif
#endif // _WIN32
#include <algorithm>
#include "gflags/gflags.h"
......@@ -32,16 +36,20 @@ limitations under the License. */
DEFINE_double(fraction_of_cpu_memory_to_use, 1,
"Default use 100% of CPU memory for PaddlePaddle,"
"reserve the rest for page tables, etc");
#if !defined(_WIN32)
DEFINE_uint64(initial_cpu_memory_in_mb,
#ifdef PADDLE_WITH_MKLDNN
/* Aligned with mozga-intel, MKLDNN need at least 5000 MB
* to obtain the best performance*/
5000,
5000ul,
#else
500,
500ul,
#endif
"Initial CPU memory for PaddlePaddle, in MD unit.");
#else
DEFINE_uint64(initial_cpu_memory_in_mb, 500ul,
"Initial CPU memory for PaddlePaddle, in MD unit.");
#endif // !defined(_WIN32)
DEFINE_double(
fraction_of_cuda_pinned_memory_to_use, 0.5,
......@@ -60,6 +68,11 @@ inline size_t CpuTotalPhysicalMemory() {
size_t len = sizeof(size);
if (sysctl(mib, 2, &size, &len, NULL, 0) == 0) return (size_t)size;
return 0L;
#elif defined(_WIN32)
MEMORYSTATUSEX sMeminfo;
sMeminfo.dwLength = sizeof(sMeminfo);
GlobalMemoryStatusEx(&sMeminfo);
return sMeminfo.ullTotalPhys;
#else
int64_t pages = sysconf(_SC_PHYS_PAGES);
int64_t page_size = sysconf(_SC_PAGE_SIZE);
......
......@@ -13,7 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#if !defined(_WIN32)
#include <sys/time.h>
#else
#include <windows.h>
#endif // !_WIN32
#include <time.h>
#include <chrono> // NOLINT
#include <string>
......@@ -27,12 +32,15 @@ namespace platform {
///////////////////////
// WARN: Under Development. Don't depend on it yet.
//////////////////////
#if !defined(_WIN32)
inline uint64_t PosixInNsec() {
struct timeval tv;
gettimeofday(&tv, nullptr);
return 1000 * (static_cast<uint64_t>(tv.tv_sec) * 1000000 + tv.tv_usec);
}
#else
inline uint64_t PosixInNsec() { return static_cast<uint64_t>(0); }
#endif // !_WIN32
// DeviceTracer performs the following tasks:
// 1. Register cuda callbacks for various events: kernel, memcpy, etc.
......
......@@ -16,7 +16,9 @@ if (CUPTI_FOUND)
list(APPEND CUDA_SRCS cupti.cc)
endif(CUPTI_FOUND)
nv_library(dynload_cuda SRCS ${CUDA_SRCS} DEPS dynamic_loader)
if (NOT WIN32)
cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc)
endif(NOT WIN32)
if (WITH_MKLML)
cc_library(dynload_mklml SRCS mklml.cc DEPS dynamic_loader mklml)
endif()
......
......@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include <dlfcn.h>
#include <memory>
#include <mutex> // NOLINT
#include <string>
......@@ -23,6 +21,7 @@ limitations under the License. */
#include "glog/logging.h"
#include "paddle/fluid/platform/dynload/cupti_lib_path.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/port.h"
DEFINE_string(cudnn_dir, "",
"Specify path for loading libcudnn.so. For instance, "
......
......@@ -18,6 +18,11 @@ limitations under the License. */
#include <cxxabi.h> // for __cxa_demangle
#endif // __GNUC__
#if defined(_WIN32)
#define NOMINMAX // msvc max/min macro conflict with std::min/max
#define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h
#endif
#ifdef PADDLE_WITH_CUDA
#include <cublas_v2.h>
#include <cudnn.h>
......@@ -117,7 +122,12 @@ struct EOFException : public std::exception {
// always forces branch prediction of true.
// This generates faster binary code. __builtin_expect is since C++11.
// For more details, please check https://stackoverflow.com/a/43870188/724872.
#if !defined(_WIN32)
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#else
// there is no equivalent intrinsics in msvc.
#define UNLIKELY(condition) (condition == 0)
#endif
template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
......@@ -230,6 +240,7 @@ inline void throw_on_error(T e) {
throw_on_error(e, "");
}
#if !defined(_WIN32)
#define PADDLE_THROW(...) \
do { \
throw ::paddle::platform::EnforceNotMet( \
......@@ -248,15 +259,28 @@ inline void throw_on_error(T e) {
__FILE__, __LINE__); \
} \
} while (false)
#else
#define PADDLE_ENFORCE(...) ::paddle::platform::throw_on_error(__VA_ARGS__);
#endif
#define PADDLE_THROW_EOF() \
do { \
throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \
__LINE__); \
} while (false)
#else
#define PADDLE_ENFORCE(...) ::paddle::platform::throw_on_error(__VA_ARGS__)
#endif // REPLACE_ENFORCE_GLOG
#else // !_WIN32
// disable enforce, caused by the varardic macro exception error
#define PADDLE_THROW(x) \
do { \
throw std::make_exception_ptr( \
std::runtime_error("Windows disable the enforce.")); \
} while (false)
#define PADDLE_ENFORCE(x, ...) x
#endif // !_WIN32
/*
* Some enforce helpers here, usage:
* int a = 1;
......
......@@ -69,6 +69,7 @@ void PushEvent(const std::string& name, const DeviceContext* dev_ctx);
void PopEvent(const std::string& name, const DeviceContext* dev_ctx);
#if !defined(_WIN32)
struct RecordEvent {
RecordEvent(const std::string& name, const DeviceContext* dev_ctx);
......@@ -94,6 +95,15 @@ struct RecordBlock {
std::string name_;
uint64_t start_ns_;
};
#else
// windows do not support profiler temporarily.
struct RecordEvent {
RecordEvent(const std::string& name, const DeviceContext* dev_ctx) {}
};
struct RecordBlock {
explicit RecordBlock(int block_id) {}
};
#endif
// Return the event list of all threads. Assumed the returned value calls
// event_lists, event_lists[i][j] represents the j-th Event of i-th thread.
......
......@@ -105,6 +105,7 @@ __all__ = [
'flatten',
'sequence_mask',
'stack',
'unstack',
]
......@@ -5601,3 +5602,44 @@ def stack(x, axis=0):
type='stack', inputs={'X': x}, outputs={'Y': out},
attrs={'axis': axis})
return out
def unstack(x, axis=0, num=None):
"""
**UnStack Layer**
This layer unstacks input :code:`x` into several tensors along axis.
If :code:`axis` < 0, it would be replaced with :code:`axis+rank(x)`.
If :code:`num` is None, it would be inferred from :code:`x.shape[axis]`,
and if :code:`x.shape[axis]` <= 0 or is unknown, :code:`ValueError` is
raised.
Args:
x (Variable): Input variable.
axis (int): The axis along which the input is unstacked.
num (int|None): The number of output variables.
Returns:
list(Variable): The unstacked variables.
"""
helper = LayerHelper('unstack', **locals())
if num is None:
if axis is None or x.shape[axis] <= 0:
raise ValueError('unknown unstack number')
else:
num = x.shape[axis]
outs = []
for _ in num:
outs.append(helper.create_tmp_variable(x.dtype))
helper.append_op(
type='unstack',
inputs={'X': [x]},
outputs={'Y': outs},
attrs={'axis': axis,
'num': num})
return outs
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from op_test import OpTest
import numpy as np
import unittest
class TestUnStackOpBase(OpTest):
def initDefaultParameters(self):
self.input_dim = (5, 6, 7)
self.axis = 0
self.dtype = 'float32'
def initParameters(self):
pass
def get_y_names(self):
y_names = []
for i in range(self.input_dim[self.axis]):
y_names.append('y{}'.format(i))
return y_names
def setUp(self):
self.initDefaultParameters()
self.initParameters()
self.op_type = 'unstack'
self.x = np.random.random(size=self.input_dim).astype(self.dtype)
outs = np.split(self.x, self.input_dim[self.axis], self.axis)
new_shape = list(self.input_dim)
del new_shape[self.axis]
y_names = self.get_y_names()
tmp = []
for i in range(self.input_dim[self.axis]):
tmp.append((y_names[i], np.reshape(outs[i], new_shape)))
self.inputs = {'X': self.x}
self.outputs = {'Y': tmp}
self.attrs = {'axis': self.axis, 'num': self.input_dim[self.axis]}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad('X', self.get_y_names())
class TestStackOp3(TestUnStackOpBase):
def initParameters(self):
self.axis = -1
class TestStackOp4(TestUnStackOpBase):
def initParameters(self):
self.axis = -3
class TestStackOp5(TestUnStackOpBase):
def initParameters(self):
self.axis = 1
class TestStackOp6(TestUnStackOpBase):
def initParameters(self):
self.axis = 2
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册