提交 b548ecbc 编写于 作者: X Xin Pan 提交者: sneaxiy

add stack_op

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace framework {
template <typename T, size_t N>
class Array {
static_assert(N > 0, "The size of array must be larger than 0");
public:
HOSTDEVICE Array() {}
HOSTDEVICE explicit Array(const T &val) {
for (size_t i = 0; i < N; ++i) data_[i] = val;
}
HOSTDEVICE const T *Get() const { return data_; }
HOSTDEVICE T *GetMutable() { return data_; }
HOSTDEVICE T &operator[](size_t index) { return data_[index]; }
HOSTDEVICE const T &operator[](size_t index) const { return data_[index]; }
HOSTDEVICE constexpr size_t size() const { return N; }
private:
T data_[N];
};
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/stack_op.h"
namespace paddle {
namespace operators {
struct CPUStackFunctor {
template <typename DeviceContext, typename T>
void operator()(const DeviceContext& ctx, const std::vector<const T*>& x,
T* y, int pre, int n, int post) const {
int total_num = pre * post * n;
for (int idx = 0; idx < total_num; ++idx) {
int i = idx / (n * post);
int which_x = idx / post - i * n;
int x_index = i * post + idx % post;
y[idx] = x[which_x][x_index];
}
}
};
struct CPUStackGradFunctor {
template <typename DeviceContext, typename T>
void operator()(const DeviceContext& ctx, std::vector<T*>& dx, // NOLINT
const T* dy, int pre, int n, int post) const {
int total_num = pre * post * n;
for (int idx = 0; idx < total_num; ++idx) {
int i = idx / (n * post);
int which_x = idx / post - i * n;
int x_index = i * post + idx % post;
dx[which_x][x_index] = dy[idx];
}
}
};
} // namespace operators
} // namespace paddle
namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker,
ops::StackGradOpDescMaker);
REGISTER_OPERATOR(stack_grad, ops::StackOpGrad);
REGISTER_OP_CPU_KERNEL(
stack,
ops::StackKernel<plat::CPUDeviceContext, float, ops::CPUStackFunctor>,
ops::StackKernel<plat::CPUDeviceContext, double, ops::CPUStackFunctor>);
REGISTER_OP_CPU_KERNEL(stack_grad,
ops::StackGradKernel<plat::CPUDeviceContext, float,
ops::CPUStackGradFunctor>,
ops::StackGradKernel<plat::CPUDeviceContext, double,
ops::CPUStackGradFunctor>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/device_vector.h>
#include "paddle/fluid/framework/array.h"
#include "paddle/fluid/operators/stack_op.h"
namespace paddle {
namespace operators {
template <typename T, typename VecXType>
__global__ void StackCUDAKernel(VecXType x, T* y, int total_num, int n,
int post) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < total_num) {
int i = idx / (n * post);
int which_x = idx / post - i * n;
int x_index = i * post + idx % post;
y[idx] = x[which_x][x_index];
}
}
template <typename T, typename VecDxType>
__global__ void StackGradCUDAKernel(VecDxType dx, const T* dy, int total_num,
int n, int post) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < total_num) {
int i = idx / (n * post);
int which_x = idx / post - i * n;
int x_index = i * post + idx % post;
dx[which_x][x_index] = dy[idx];
}
}
struct GPUStackFunctor {
template <typename DeviceContext, typename T>
void operator()(const DeviceContext& ctx, const std::vector<const T*>& x,
T* y, int pre, int n, int post) const {
int total_num = pre * post * n;
int threads = 512;
int grid = (total_num + threads - 1) / threads;
constexpr auto kMaxThreshold = 16;
if (n <= kMaxThreshold) {
framework::Array<const T*, kMaxThreshold> arr;
for (int i = 0; i < n; ++i) arr[i] = x[i];
StackCUDAKernel<<<grid, threads, 0, ctx.stream()>>>(arr, y, total_num, n,
post);
} else {
VLOG(10) << "Stack more than " << kMaxThreshold
<< " tensors may be slow on GPU.";
thrust::device_vector<const T*> dev_x(x);
StackCUDAKernel<<<grid, threads, 0, ctx.stream()>>>(dev_x.data().get(), y,
total_num, n, post);
}
}
};
struct GPUStackGradFunctor {
template <typename DeviceContext, typename T>
void operator()(const DeviceContext& ctx, std::vector<T*>& dx, // NOLINT
const T* dy, int pre, int n, int post) const {
int total_num = pre * post * n;
int threads = 512;
int grid = (total_num + threads - 1) / threads;
constexpr auto kMaxThreshold = 16;
if (n <= kMaxThreshold) {
framework::Array<T*, kMaxThreshold> arr;
for (int i = 0; i < n; ++i) arr[i] = dx[i];
StackGradCUDAKernel<<<grid, threads, 0, ctx.stream()>>>(
arr, dy, total_num, n, post);
} else {
VLOG(10) << "Stack more than " << kMaxThreshold
<< " tensors may be slow on GPU.";
thrust::device_vector<T*> dev_dx(dx);
StackGradCUDAKernel<<<grid, threads, 0, ctx.stream()>>>(
dev_dx.data().get(), dy, total_num, n, post);
}
}
};
} // namespace operators
} // namespace paddle
namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
stack,
ops::StackKernel<plat::CUDADeviceContext, float, ops::GPUStackFunctor>,
ops::StackKernel<plat::CUDADeviceContext, double, ops::GPUStackFunctor>);
REGISTER_OP_CUDA_KERNEL(stack_grad,
ops::StackGradKernel<plat::CUDADeviceContext, float,
ops::GPUStackGradFunctor>,
ops::StackGradKernel<plat::CUDADeviceContext, double,
ops::GPUStackGradFunctor>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
inline void GetPrePostForStackOp(const framework::DDim &dim, int axis, int *pre,
int *post) {
*pre = 1;
for (auto i = 0; i < axis; ++i) (*pre) *= dim[i];
*post = 1;
for (auto i = axis; i < dim.size(); ++i) (*post) *= dim[i];
}
class StackOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 0,
"Number of Inputs(X) must be larger than 0");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist.");
auto input_dims = ctx->GetInputsDim("X");
for (size_t i = 1; i < input_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
"Dims of all Inputs(X) must be the same");
}
// Only lod of X[0] would be shared with Y
ctx->ShareLoD("X", /*->*/ "Y");
int axis = ctx->Attrs().Get<int>("axis");
int rank = input_dims[0].size();
PADDLE_ENFORCE(
axis >= -(rank + 1) && axis < rank + 1,
"Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
if (axis < 0) axis += (rank + 1);
auto vec = framework::vectorize2int(input_dims[0]);
vec.insert(vec.begin() + axis, input_dims.size());
ctx->SetOutputDim("Y", framework::make_ddim(vec));
}
};
class StackOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input of stack op.").AsDuplicable();
AddOutput("Y", "The output of stack op.");
AddAttr<int>("axis",
"The axis along which all of the Inputs(X) should be stacked.")
.SetDefault(0);
AddComment(R"DOC(
Stack Operator.
Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inputs(X) must be the same.
)DOC");
}
};
template <typename DeviceContext, typename T, typename Functor>
class StackKernel : public framework::OpKernel<T> {
using Tensor = framework::LoDTensor;
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto x = ctx.MultiInput<Tensor>("X");
auto *y = ctx.Output<Tensor>("Y");
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis += (x[0]->dims().size() + 1);
int n = static_cast<int>(x.size());
auto *y_data = y->mutable_data<T>(ctx.GetPlace());
std::vector<const T *> x_datas(n);
for (int i = 0; i < n; i++) x_datas[i] = x[i]->data<T>();
int pre = 1, post = 1;
auto &dim = x[0]->dims();
for (auto i = 0; i < axis; ++i) pre *= dim[i];
for (auto i = axis; i < dim.size(); ++i) post *= dim[i];
Functor functor;
functor(ctx.template device_context<DeviceContext>(), x_datas, y_data, pre,
n, post);
}
};
class StackOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"Input(Y@Grad) must exist.");
int axis = ctx->Attrs().Get<int>("axis");
auto dy_dim = ctx->GetInputDim(framework::GradVarName("Y"));
int rank = dy_dim.size();
PADDLE_ENFORCE(axis >= -rank && axis < rank,
"Attr(axis) must be inside [-rank, rank), where rank = %d",
rank);
if (axis < 0) axis += rank;
PADDLE_ENFORCE_EQ(ctx->Outputs(framework::GradVarName("X")).size(),
static_cast<size_t>(dy_dim[axis]),
"Number of Outputs(X@Grad) is wrong");
auto vec = framework::vectorize2int(dy_dim);
vec.erase(vec.begin() + axis);
ctx->SetOutputsDim(
framework::GradVarName("X"),
std::vector<framework::DDim>(dy_dim[axis], framework::make_ddim(vec)));
}
};
class StackGradOpDescMaker
: public framework::
SingleGradOpDescMaker /*framework::GradOpDescMakerBase*/ {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
/*
using framework::GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<framework::OpDesc>> operator ()() const override {
auto x_grads = InputGrad("X", false);
std::vector<std::unique_ptr<framework::OpDesc>> grad_ops;
grad_ops.reserve(x_grads.size());
auto og = OutputGrad("Y");
std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
[&og](const std::string& x_grad) {
auto* grad_op = new framework::OpDesc();
grad_op->SetInput("X", og);
grad_op->SetOutput("Y", {x_grad});
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(grad_op);
});
return grad_ops;
}
*/
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("stack_grad");
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X", false));
op->SetAttrMap(Attrs());
return op;
}
};
template <typename DeviceContext, typename T, typename GradFunctor>
class StackGradKernel : public framework::OpKernel<T> {
using Tensor = framework::LoDTensor;
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis += dy->dims().size();
int n = dy->dims()[axis];
std::vector<T *> dx_datas(n); // NOLINT
for (int i = 0; i < n; i++)
dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
auto dy_data = dy->data<T>();
int pre = 1;
for (int i = 0; i < axis; ++i) pre *= dy->dims()[i];
int post = dy->numel() / (n * pre);
GradFunctor functor;
functor(ctx.template device_context<DeviceContext>(), dx_datas, dy_data,
pre, n, post);
}
};
} // namespace operators
} // namespace paddle
......@@ -57,12 +57,12 @@ class WhileOp : public framework::OperatorBase {
PADDLE_ENFORCE(platform::is_cpu_place(cond.place()),
"Condition of while op must in CPU memory.");
auto ctx = executor.Prepare(*program, block->ID());
while (cond.data<bool>()[0]) {
auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope);
executor.Run(*program, &current_scope, block->ID(),
false /*create_local_scope*/);
executor.RunPreparedContext(ctx.get(), &current_scope, false);
}
}
};
......@@ -109,6 +109,7 @@ class WhileGradOp : public framework::OperatorBase {
framework::Executor executor(dev_place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program();
auto ctx = executor.Prepare(*program, block->ID());
auto *step_scopes =
scope.FindVar(Input(kStepScopes))->GetMutable<StepScopeVar>();
......@@ -161,8 +162,7 @@ class WhileGradOp : public framework::OperatorBase {
}
}
}
executor.Run(*program, *cur_scope_iter, block->ID(), false);
executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false);
auto &pg_names = Outputs(kXGRAD);
auto &p_names = Inputs(kX);
......
......@@ -29,80 +29,21 @@ from .. import unique_name
from functools import reduce
__all__ = [
'fc',
'embedding',
'dynamic_lstm',
'dynamic_lstmp',
'dynamic_gru',
'gru_unit',
'linear_chain_crf',
'crf_decoding',
'cos_sim',
'cross_entropy',
'square_error_cost',
'chunk_eval',
'sequence_conv',
'conv2d',
'conv3d',
'sequence_pool',
'sequence_softmax',
'softmax',
'pool2d',
'pool3d',
'batch_norm',
'beam_search_decode',
'conv2d_transpose',
'conv3d_transpose',
'sequence_expand',
'lstm_unit',
'reduce_sum',
'reduce_mean',
'reduce_max',
'reduce_min',
'reduce_prod',
'sequence_first_step',
'sequence_last_step',
'dropout',
'split',
'ctc_greedy_decoder',
'edit_distance',
'l2_normalize',
'matmul',
'topk',
'warpctc',
'sequence_reshape',
'transpose',
'im2sequence',
'nce',
'hsigmoid',
'beam_search',
'row_conv',
'multiplex',
'layer_norm',
'softmax_with_cross_entropy',
'smooth_l1',
'one_hot',
'autoincreased_step_counter',
'reshape',
'lod_reset',
'lrn',
'pad',
'label_smooth',
'roi_pool',
'dice_loss',
'image_resize',
'image_resize_short',
'resize_bilinear',
'gather',
'scatter',
'random_crop',
'mean_iou',
'relu',
'log',
'crop',
'rank_loss',
'prelu',
'flatten',
'fc', 'embedding', 'dynamic_lstm', 'dynamic_lstmp', 'dynamic_gru',
'gru_unit', 'linear_chain_crf', 'crf_decoding', 'cos_sim', 'cross_entropy',
'square_error_cost', 'chunk_eval', 'sequence_conv', 'conv2d', 'conv3d',
'sequence_pool', 'sequence_softmax', 'softmax', 'pool2d', 'pool3d',
'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'conv3d_transpose',
'sequence_expand', 'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max',
'reduce_min', 'reduce_prod', 'sequence_first_step', 'sequence_last_step',
'dropout', 'split', 'ctc_greedy_decoder', 'edit_distance', 'l2_normalize',
'matmul', 'topk', 'warpctc', 'sequence_reshape', 'transpose', 'im2sequence',
'nce', 'hsigmoid', 'beam_search', 'row_conv', 'multiplex', 'layer_norm',
'softmax_with_cross_entropy', 'smooth_l1', 'one_hot',
'autoincreased_step_counter', 'reshape', 'lod_reset', 'lrn', 'pad',
'label_smooth', 'roi_pool', 'dice_loss', 'image_resize',
'image_resize_short', 'resize_bilinear', 'gather', 'scatter', 'random_crop',
'mean_iou', 'relu', 'log', 'crop', 'rank_loss', 'prelu', 'flatten', 'stack'
]
......@@ -5517,3 +5458,16 @@ def flatten(x, axis=1, name=None):
outputs={'Out': out},
attrs={"axis": axis})
return out
def stack(x, axis=0):
helper = LayerHelper('stack', **locals())
axis = 0 if axis is None else axis
if not isinstance(x, list) and not isinstance(x, tuple):
x = [x]
out = helper.create_tmp_variable(x[0].dtype)
helper.append_op(
type='stack', inputs={'X': x}, outpus={'Y': out}, attrs={'axis': axis})
return out
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from op_test import OpTest
import numpy as np
import unittest
class TestStackOpBase(OpTest):
def initDefaultParameters(self):
self.num_inputs = 4
self.input_dim = (5, 6, 7)
self.axis = 0
self.dtype = 'float32'
def initParameters(self):
pass
def get_x_names(self):
x_names = []
for i in range(self.num_inputs):
x_names.append('x{}'.format(i))
return x_names
def setUp(self):
self.initDefaultParameters()
self.initParameters()
self.op_type = 'stack'
self.x = []
for i in range(self.num_inputs):
self.x.append(
np.random.random(size=self.input_dim).astype(self.dtype))
tmp = []
x_names = self.get_x_names()
for i in range(self.num_inputs):
tmp.append((x_names[i], self.x[i]))
self.inputs = {'X': tmp}
self.outputs = {'Y': np.stack(self.x, axis=self.axis)}
self.attrs = {'axis': self.axis}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(self.get_x_names(), 'Y')
class TestStackOp1(TestStackOpBase):
def initParameters(self):
self.num_inputs = 16
class TestStackOp2(TestStackOpBase):
def initParameters(self):
self.num_inputs = 20
class TestStackOp3(TestStackOpBase):
def initParameters(self):
self.axis = -1
class TestStackOp4(TestStackOpBase):
def initParameters(self):
self.axis = -4
class TestStackOp5(TestStackOpBase):
def initParameters(self):
self.axis = 1
class TestStackOp6(TestStackOpBase):
def initParameters(self):
self.axis = 3
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册