Commit 2a81c367 authored by zhongpu, committed by Jiabin Yang

add kernel for unstack_op, test=develop (#19538)

* add kernel for unstack_op, test=develop

* add kernel for unstack_op, test=develop

* add kernel for unstack_op, test=develop

* adjust the code format, test=develop

* modify some comment, test=develop
Parent 00d5375e
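This change replaces the old implementation, in which the forward unstack op simply ran a stack_grad op and the backward op ran a stack op (hence the USE_OP(stack) dependency that is dropped below), with dedicated UnStackKernel and UnStackGradKernel functors registered for both CPU and CUDA. The operator and shape-inference definitions also move from unstack_op.h into unstack_op.cc.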
paddle/fluid/operators/unstack_op.cc
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/unstack_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
class UnStackOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) must exist.");
int axis = ctx->Attrs().Get<int>("axis");
int num = ctx->Attrs().Get<int>("num");
auto x_dim = ctx->GetInputDim("X");
int rank = x_dim.size();
PADDLE_ENFORCE_GE(
axis, -rank, "Attr(axis) must be inside [-rank, rank), where rank = %d",
rank);
PADDLE_ENFORCE_LT(
axis, rank, "Attr(axis) must be inside [-rank, rank), where rank = %d",
rank);
if (axis < 0) axis += rank;
PADDLE_ENFORCE_EQ(ctx->Outputs("Y").size(), static_cast<size_t>(num),
"Number of Outputs(Y) is wrong");
if (x_dim[axis] > 0) {
PADDLE_ENFORCE_EQ(num, x_dim[axis], "Number of Outputs(Y) is wrong");
}
auto vec = framework::vectorize<int>(x_dim);
vec.erase(vec.begin() + axis);
ctx->SetOutputsDim("Y", std::vector<framework::DDim>( // NOLINT
x_dim[axis], framework::make_ddim(vec)));
}
};
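// Example (added for illustration): for an input X of shape [2, 3, 4] and
// axis = 1, the checks above require num == 3, and each of the 3 outputs
// in Y gets shape [2, 4].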
class UnStackOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input of unstack op.");
AddOutput("Y", "The output of unstack op.").AsDuplicable();
AddAttr<int>("axis", "The axis along which Input(X) should be unstacked.")
.SetDefault(0);
AddAttr<int>("num", "The number of outputs(Y).").GreaterThan(0);
AddComment(R"DOC(
UnStack Operator.
UnStack Input(X) into several tensors along Attr(axis).
)DOC");
}
};
class UnStackGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("unstack_grad");
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
class UnStackGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_GT(ctx->Inputs(framework::GradVarName("Y")).size(), 0,
"Number of Inputs(Y@Grad) must be larger than 0");
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
"Output(X@Grad) must exist.");
auto input_dims = ctx->GetInputsDim(framework::GradVarName("Y"));
for (size_t i = 1; i < input_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
"Dims of all Inputs(Y@Grad) must be the same");
}
int axis = ctx->Attrs().Get<int>("axis");
int rank = input_dims[0].size();
PADDLE_ENFORCE_GE(
axis, -(rank + 1),
"Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
PADDLE_ENFORCE_LT(
axis, rank + 1,
"Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
if (axis < 0) axis += (rank + 1);
auto vec = framework::vectorize<int>(input_dims[0]);
vec.insert(vec.begin() + axis, input_dims.size());
ctx->SetOutputDim(framework::GradVarName("X"), framework::make_ddim(vec));
}
};
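// Example (added for illustration): stacking 3 gradient tensors of shape
// [2, 4] back along axis = 1 reconstructs X@Grad with shape [2, 3, 4],
// the inverse of UnStackOp::InferShape above.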
} // namespace operators
} // namespace paddle
namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OPERATOR(unstack, ops::UnStackOp, ops::UnStackOpMaker,
                  ops::UnStackGradOpDescMaker);
REGISTER_OPERATOR(unstack_grad, ops::UnStackGradOp);
REGISTER_OP_CPU_KERNEL(unstack,
ops::UnStackKernel<plat::CPUDeviceContext, float>,
ops::UnStackKernel<plat::CPUDeviceContext, double>,
ops::UnStackKernel<plat::CPUDeviceContext, int>,
ops::UnStackKernel<plat::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(unstack_grad,
ops::UnStackGradKernel<plat::CPUDeviceContext, float>,
ops::UnStackGradKernel<plat::CPUDeviceContext, double>,
ops::UnStackGradKernel<plat::CPUDeviceContext, int>,
ops::UnStackGradKernel<plat::CPUDeviceContext, int64_t>);
paddle/fluid/operators/unstack_op.cu
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/unstack_op.h"
namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
unstack, ops::UnStackKernel<plat::CUDADeviceContext, float>,
ops::UnStackKernel<plat::CUDADeviceContext, double>,
ops::UnStackKernel<plat::CUDADeviceContext, int>,
ops::UnStackKernel<plat::CUDADeviceContext, int64_t>,
ops::UnStackKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
unstack_grad, ops::UnStackGradKernel<plat::CUDADeviceContext, float>,
ops::UnStackGradKernel<plat::CUDADeviceContext, double>,
ops::UnStackGradKernel<plat::CUDADeviceContext, int>,
ops::UnStackGradKernel<plat::CUDADeviceContext, int64_t>,
ops::UnStackGradKernel<plat::CUDADeviceContext, plat::float16>);
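Note that the CUDA registrations also cover plat::float16, which has no CPU counterpart in this commit; the CPU kernels above are registered only for float, double, int, and int64_t.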
paddle/fluid/operators/unstack_op.h
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
#ifdef __NVCC__
#include <thrust/device_vector.h>
#include "paddle/fluid/framework/array.h"
#endif
namespace paddle {
namespace operators {
template <typename VecXType, typename T>
struct StackFunctor {
HOSTDEVICE StackFunctor(const VecXType &x, T *y, int n, int post)
: x_(x), y_(y), n_(n), post_(post) {}
HOSTDEVICE void operator()(int idx) {
int i = idx / (n_ * post_);
int which_x = idx / post_ - i * n_;
int x_index = i * post_ + idx % post_;
y_[idx] = x_[which_x][x_index];
}
private:
VecXType x_;
T *y_;
int n_;
int post_;
};
template <typename VecDxType, typename T>
struct StackGradFunctor {
HOSTDEVICE StackGradFunctor(const VecDxType &dx, const T *dy, int n, int post)
: dx_(dx), dy_(dy), n_(n), post_(post) {}
HOSTDEVICE void operator()(int idx) {
int i = idx / (n_ * post_);
int which_x = idx / post_ - i * n_;
int x_index = i * post_ + idx % post_;
dx_[which_x][x_index] = dy_[idx];
}
private:
VecDxType dx_;
const T *dy_;
int n_;
int post_;
};
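// Worked example for the index mapping above (added for illustration).
// Treat the stacked tensor as flattened [pre, n, post] with pre = 2,
// n = 3, post = 4, and take idx = 17:
//   i       = 17 / (3 * 4)   = 1   // position along the "pre" dims
//   which_x = 17 / 4 - 1 * 3 = 1   // which of the n slices
//   x_index = 1 * 4 + 17 % 4 = 5   // offset within that slice
// So dy[17] is routed to slice 1 at offset 5; StackFunctor applies the
// same decomposition in the opposite direction.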
template <typename DeviceContext, typename VecXType, typename T>
static inline void StackFunctorForRange(const DeviceContext &ctx,
const VecXType &x, T *y, int total_num,
int n, int post) {
platform::ForRange<DeviceContext> for_range(ctx, total_num);
for_range(StackFunctor<VecXType, T>(x, y, n, post));
}
template <typename DeviceContext, typename VecDxType, typename T>
static inline void StackGradFunctorForRange(const DeviceContext &ctx,
const VecDxType &dx, const T *dy,
int total_num, int n, int post) {
platform::ForRange<DeviceContext> for_range(ctx, total_num);
for_range(StackGradFunctor<VecDxType, T>(dx, dy, n, post));
}
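// Both helpers evaluate their functor once for every idx in
// [0, total_num); platform::ForRange executes this as a plain loop on
// CPU and as a CUDA kernel launch on GPU.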
template <typename DeviceContext, typename T>
class UnStackGradKernel : public framework::OpKernel<T> {
using Tensor = framework::LoDTensor;
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto x = ctx.MultiInput<Tensor>(framework::GradVarName("Y"));
auto *y = ctx.Output<Tensor>(framework::GradVarName("X"));
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis += (x[0]->dims().size() + 1);
int n = static_cast<int>(x.size());
auto *y_data = y->mutable_data<T>(ctx.GetPlace());
std::vector<const T *> x_datas(n);
for (int i = 0; i < n; i++) x_datas[i] = x[i]->data<T>();
int pre = 1;
int post = 1;
auto &dim = x[0]->dims();
for (auto i = 0; i < axis; ++i) pre *= dim[i];
for (auto i = axis; i < dim.size(); ++i) post *= dim[i];
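// Example (added for illustration): if each x[i] has shape [2, 4] and
// axis == 1, then pre = 2 and post = 4, so y holds pre * n * post
// elements laid out as [2, n, 4].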
#ifdef __NVCC__
int total_num = pre * n * post;
auto &dev_ctx = ctx.template device_context<DeviceContext>();
thrust::device_vector<const T *> device_x_vec(x_datas);
auto x_data_arr = device_x_vec.data().get();
StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
// Wait() must be called because device_x_vec may be destructed before
// kernel ends
dev_ctx.Wait();
#else
auto x_data_arr = x_datas.data();
size_t x_offset = 0;
size_t y_offset = 0;
for (int i = 0; i < pre; i++) {
for (int j = 0; j < n; j++) {
std::memcpy(y_data + y_offset, x_data_arr[j] + x_offset,
post * sizeof(T));
y_offset += post;
}
x_offset += post;
}
#endif
}
};
template <typename DeviceContext, typename T>
class UnStackKernel : public framework::OpKernel<T> {
using Tensor = framework::LoDTensor;
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *dy = ctx.Input<Tensor>("X");
auto dx = ctx.MultiOutput<Tensor>("Y");
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis += dy->dims().size();
int n = dy->dims()[axis];
std::vector<T *> dx_datas(n); // NOLINT
for (int i = 0; i < n; i++) {
dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
}
auto dy_data = dy->data<T>();
int pre = 1;
for (int i = 0; i < axis; ++i) pre *= dy->dims()[i];
int total_num = dy->numel();
int post = total_num / (n * pre);
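// Since n == dy->dims()[axis], post is the product of the dims after
// axis, and total_num == pre * n * post.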
auto &dev_ctx = ctx.template device_context<DeviceContext>();
#ifdef __NVCC__
thrust::device_vector<T *> device_dx_vec(dx_datas);
auto dx_data_arr = device_dx_vec.data().get();
#else
auto dx_data_arr = dx_datas.data();
#endif
StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, post);
#ifdef __NVCC__
// Wait() must be called because device_dx_vec may be destructed before
// kernel ends
dev_ctx.Wait();
#endif
}
};
......
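To see the pre/n/post index arithmetic outside the framework, here is a minimal standalone sketch (plain C++, written for this note and not part of the commit) that applies the same mapping as StackGradFunctor to unstack a flattened 2x3x4 buffer along axis 1:

#include <cstdio>
#include <vector>

int main() {
  // Unstack a [2, 3, 4] tensor along axis 1: pre = 2, n = 3, post = 4.
  const int pre = 2, n = 3, post = 4;
  const int total_num = pre * n * post;

  // Flattened input; dy[idx] = idx makes the routing easy to inspect.
  std::vector<int> dy(total_num);
  for (int idx = 0; idx < total_num; ++idx) dy[idx] = idx;

  // n output slices of pre * post elements each, as in UnStackKernel.
  std::vector<std::vector<int>> dx(n, std::vector<int>(pre * post));

  for (int idx = 0; idx < total_num; ++idx) {
    int i = idx / (n * post);             // position along the "pre" dims
    int which_x = idx / post - i * n;     // which output slice
    int x_index = i * post + idx % post;  // offset within that slice
    dx[which_x][x_index] = dy[idx];
  }

  for (int k = 0; k < n; ++k) {
    std::printf("slice %d:", k);
    for (int v : dx[k]) std::printf(" %d", v);
    std::printf("\n");
  }
  return 0;
}

Slice 0 prints 0 1 2 3 12 13 14 15, i.e. the elements of the original tensor at position 0 along axis 1, which matches what the CPU and CUDA kernels above produce.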