提交 5f627488 编写于 作者: Z zhongpu 提交者: Jiabin Yang

add kernel for unsqueeze_op and Add unsqueezed op test, test=develop (#19436)

* add kernel for unsqueeze_op, test=develop

* add kernel for unsqueeze_op, test=develop

* add kernel for unsqueeze_op, test=develop
上级 a7691603
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/unsqueeze_op.h"
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -19,20 +21,22 @@ limitations under the License. */ ...@@ -19,20 +21,22 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class UnsqueezeOpInferShape : public framework::InferShapeBase { class UnsqueezeOp : public framework::OperatorWithKernel {
public: public:
void operator()(framework::InferShapeContext *ctx) const override { using framework::OperatorWithKernel::OperatorWithKernel;
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of Unsqueeze operator should not be null."); void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Output(Out) of Unsqueeze operator should not be null."); "Input(X) of Unsqueeze operator should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of Unsqueeze operator should not be null.");
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes"); const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
const auto &x_dims = ctx->GetInputDim("X"); const auto &x_dims = ctx->GetInputDim("X");
// Validity Check: input tensor dims (<6). // Validity Check: input tensor dims (<6).
PADDLE_ENFORCE(x_dims.size() <= 6, PADDLE_ENFORCE_LE(x_dims.size(), 6,
"Invalid dimensions, the rank of Input(X) " "Invalid dimensions, the rank of Input(X) "
"should be in the range of [1, 6] (Eigen limit)"); "should be in the range of [1, 6] (Eigen limit)");
auto out_dims = GetOutputShape(axes, x_dims); auto out_dims = GetOutputShape(axes, x_dims);
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) { if (x_dims[0] == out_dims[0]) {
...@@ -49,15 +53,14 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { ...@@ -49,15 +53,14 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase {
std::vector<int64_t> output_shape(output_size, 0); std::vector<int64_t> output_shape(output_size, 0);
// Validity Check: rank range. // Validity Check: rank range.
PADDLE_ENFORCE(output_size <= 6, PADDLE_ENFORCE_LE(output_size, 6,
"The output tensor's rank should be less than 6."); "The output tensor's rank should be less than 6.");
for (int axis : unsqz_dims) { for (int axis : unsqz_dims) {
int cur = axis < 0 ? axis + cur_output_size + 1 : axis; int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
// Vaildity Check: the axis bound // Vaildity Check: the axis bound
PADDLE_ENFORCE( PADDLE_ENFORCE_GE(cur, 0);
cur >= 0 && cur <= cur_output_size, PADDLE_ENFORCE_LE(cur, cur_output_size);
"The unsqueeze dims must be within range of current rank.");
// Move old axis, and insert new axis // Move old axis, and insert new axis
for (int i = cur_output_size; i >= cur; --i) { for (int i = cur_output_size; i >= cur; --i) {
if (output_shape[i] == 1) { if (output_shape[i] == 1) {
...@@ -82,27 +85,6 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { ...@@ -82,27 +85,6 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase {
} }
}; };
class UnsqueezeOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims);
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims);
// Invoke Reshape op.
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {Input("X")}}, {"Shape", {}}},
{{"Out", {Output("Out")}}}, attrs);
reshape_op->Run(scope, place);
}
};
class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
...@@ -112,17 +94,17 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -112,17 +94,17 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
"(std::vector<int>). List of integers," "(std::vector<int>). List of integers,"
" indicating the dimensions to be inserted") " indicating the dimensions to be inserted")
.AddCustomChecker([](const std::vector<int> &axes) { .AddCustomChecker([](const std::vector<int> &axes) {
PADDLE_ENFORCE(!axes.empty(), PADDLE_ENFORCE_EQ(!axes.empty(), true,
"Invalid axes, The unsqueeze axes is empty."); "Invalid axes, The unsqueeze axes is empty.");
// Validity Check: axes dims (<6). // Validity Check: axes dims (<6).
PADDLE_ENFORCE(static_cast<int>(axes.size()) < 6, PADDLE_ENFORCE_LT(static_cast<int>(axes.size()), 6,
"Invalid dimensions, dynamic dimensions should be " "Invalid dimensions, dynamic dimensions should be "
"within [1, 6] dimensions (Eigen limit)."); "within [1, 6] dimensions (Eigen limit).");
// Validity Check: the range of unsqueeze aixs. // Validity Check: the range of unsqueeze aixs.
for (int axis : axes) { for (int axis : axes) {
PADDLE_ENFORCE(axis < 6, PADDLE_ENFORCE_LT(axis, 6,
"Invalid dimensions, input axis should be" "Invalid dimensions, input axis should be"
" within [1, 6] dimensions (Eigen limit)."); " within [1, 6] dimensions (Eigen limit).");
} }
}); });
AddComment(R"DOC( AddComment(R"DOC(
...@@ -139,47 +121,47 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -139,47 +121,47 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
} }
}; };
class UnsqueezeGradInferShape : public framework::InferShapeBase { class UnsqueezeGradOp : public framework::OperatorWithKernel {
public: public:
void operator()(framework::InferShapeContext *ctx) const override { using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", framework::GradVarName("X")); ctx->ShareLoD("X", framework::GradVarName("X"));
} }
}; };
class UnsqueezeGradOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(x_dims);
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
attrs);
reshape_op->Run(scope, place);
}
};
// FIXME(zcd): unsqueeze2 adds an intermediate output(XShape) based on // FIXME(zcd): unsqueeze2 adds an intermediate output(XShape) based on
// unsqueeze, the XShape is used to carry the shape and lod of X which // unsqueeze, the XShape is used to carry the shape and lod of X which
// will be used in unsqueeze_grad, in this way, the framework can reuse // will be used in unsqueeze_grad, in this way, the framework can reuse
// the memory of X immediately the unsqueeze2_op is finished. // the memory of X immediately the unsqueeze2_op is finished.
// Considering compatibility issues, we could not fix unsqueeze2_op // Considering compatibility issues, we could not fix unsqueeze2_op
class Unsqueeze2OpInferShape : public UnsqueezeOpInferShape { class Unsqueeze2Op : public framework::OperatorWithKernel {
public: public:
void operator()(framework::InferShapeContext *ctx) const override { using framework::OperatorWithKernel::OperatorWithKernel;
UnsqueezeOpInferShape::operator()(ctx); void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("XShape"), PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Output(XShape) of Unsqueeze operator should not be null."); "Input(X) of Unsqueeze operator should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of Unsqueeze operator should not be null.");
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
const auto &x_dims = ctx->GetInputDim("X"); const auto &x_dims = ctx->GetInputDim("X");
// Validity Check: input tensor dims (<6).
PADDLE_ENFORCE_LE(x_dims.size(), 6,
"Invalid dimensions, the rank of Input(X) "
"should be in the range of [1, 6] (Eigen limit)");
auto out_dims = UnsqueezeOp::GetOutputShape(axes, x_dims);
ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx->ShareLoD("X", "Out");
}
PADDLE_ENFORCE_EQ(
ctx->HasOutput("XShape"), true,
"Output(XShape) of Unsqueeze operator should not be null.");
std::vector<int64_t> xshape_dims(x_dims.size() + 1); std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0; xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) { for (int i = 0; i < x_dims.size(); ++i) {
...@@ -201,27 +183,6 @@ class Unsqueeze2OpMaker : public UnsqueezeOpMaker { ...@@ -201,27 +183,6 @@ class Unsqueeze2OpMaker : public UnsqueezeOpMaker {
} }
}; };
class Unsqueeze2Op : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = Unsqueeze2OpInferShape::GetOutputShape(axes, x_dims);
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims);
// Invoke Reshape op.
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape2", {{"X", {Input("X")}}, {"Shape", {}}},
{{"Out", {Output("Out")}}, {"XShape", {Output("XShape")}}}, attrs);
reshape_op->Run(scope, place);
}
};
class Unsqueeze2GradOpMaker : public framework::SingleGradOpDescMaker { class Unsqueeze2GradOpMaker : public framework::SingleGradOpDescMaker {
public: public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
...@@ -237,43 +198,26 @@ class Unsqueeze2GradOpMaker : public framework::SingleGradOpDescMaker { ...@@ -237,43 +198,26 @@ class Unsqueeze2GradOpMaker : public framework::SingleGradOpDescMaker {
} }
}; };
class Unsqueeze2GradInferShape : public framework::InferShapeBase { class Unsqueeze2GradOp : public framework::OperatorWithKernel {
public: public:
void operator()(framework::InferShapeContext *context) const override { using framework::OperatorWithKernel::OperatorWithKernel;
PADDLE_ENFORCE(context->HasInput("XShape"), void InferShape(framework::InferShapeContext *context) const override {
"Input(XShape) shouldn't be null."); PADDLE_ENFORCE_EQ(context->HasInput("XShape"), true,
PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")), "Input(XShape) shouldn't be null.");
"Input(Out@GRAD) shouldn't be null."); PADDLE_ENFORCE_EQ(context->HasInput(framework::GradVarName("Out")), true,
"Input(Out@GRAD) shouldn't be null.");
auto xshape_dims = context->GetInputDim("XShape"); auto xshape_dims = context->GetInputDim("XShape");
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
context->SetOutputDim(framework::GradVarName("X"), x_dims); context->SetOutputDim(framework::GradVarName("X"), x_dims);
context->ShareLoD("XShape", framework::GradVarName("X")); context->ShareLoD("XShape", framework::GradVarName("X"));
} }
};
class Unsqueeze2GradOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
auto xshape_name = Input("XShape");
auto xshape_dims =
scope.FindVar(xshape_name)->Get<framework::LoDTensor>().dims();
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(x_dims);
auto reshape_op = framework::OpRegistry::CreateOp( protected:
"reshape2_grad", {{framework::GradVarName("Out"), {dout_name}}, framework::OpKernelType GetExpectedKernelType(
{"Shape", {}}, const framework::ExecutionContext &ctx) const override {
{"XShape", {xshape_name}}}, return framework::OpKernelType(
{{framework::GradVarName("X"), {dx_name}}}, attrs); ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
reshape_op->Run(scope, place); ctx.device_context());
} }
}; };
...@@ -281,23 +225,43 @@ DECLARE_INPLACE_OP_INFERER(UnsqueezeInplaceInferer, {"X", "Out"}); ...@@ -281,23 +225,43 @@ DECLARE_INPLACE_OP_INFERER(UnsqueezeInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(UnsqueezeGradInplaceInferer, DECLARE_INPLACE_OP_INFERER(UnsqueezeGradInplaceInferer,
{framework::GradVarName("Out"), {framework::GradVarName("Out"),
framework::GradVarName("X")}); framework::GradVarName("X")});
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
// Tell linker to use reshape op.
USE_OP(reshape);
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker, REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
ops::UnsqueezeOpInferShape,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp, REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp);
ops::UnsqueezeGradInferShape);
REGISTER_OPERATOR(unsqueeze2, ops::Unsqueeze2Op, ops::Unsqueeze2OpMaker, REGISTER_OPERATOR(unsqueeze2, ops::Unsqueeze2Op, ops::Unsqueeze2OpMaker,
ops::Unsqueeze2OpInferShape, ops::Unsqueeze2GradOpMaker, ops::Unsqueeze2GradOpMaker, ops::UnsqueezeInplaceInferer);
ops::UnsqueezeInplaceInferer);
REGISTER_OPERATOR(unsqueeze2_grad, ops::Unsqueeze2GradOp, REGISTER_OPERATOR(unsqueeze2_grad, ops::Unsqueeze2GradOp,
ops::Unsqueeze2GradInferShape,
ops::UnsqueezeGradInplaceInferer); ops::UnsqueezeGradInplaceInferer);
// CPU kernel registrations: each unsqueeze op variant is instantiated for
// CPUDeviceContext over the element types float, double, int, int8_t and
// int64_t. unsqueeze2 additionally produces the XShape output used by its
// gradient op.
REGISTER_OP_CPU_KERNEL(
    unsqueeze, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>,
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
    unsqueeze_grad,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, double>,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
    unsqueeze2,
    ops::Unsqueeze2Kernel<paddle::platform::CPUDeviceContext, float>,
    ops::Unsqueeze2Kernel<paddle::platform::CPUDeviceContext, double>,
    ops::Unsqueeze2Kernel<paddle::platform::CPUDeviceContext, int>,
    ops::Unsqueeze2Kernel<paddle::platform::CPUDeviceContext, int8_t>,
    ops::Unsqueeze2Kernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
    unsqueeze2_grad,
    ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, double>,
    ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
    ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
    ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/unsqueeze_op.h"
namespace ops = paddle::operators;
// CUDA kernel registrations: mirrors the CPU registrations in
// unsqueeze_op.cc, instantiating the same element types (float, double, int,
// int8_t, int64_t) for CUDADeviceContext.
REGISTER_OP_CUDA_KERNEL(
    unsqueeze, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
    ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>,
    ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
    ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
    ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
    unsqueeze_grad,
    ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
    ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
    unsqueeze2,
    ops::Unsqueeze2Kernel<paddle::platform::CUDADeviceContext, float>,
    ops::Unsqueeze2Kernel<paddle::platform::CUDADeviceContext, double>,
    ops::Unsqueeze2Kernel<paddle::platform::CUDADeviceContext, int>,
    ops::Unsqueeze2Kernel<paddle::platform::CUDADeviceContext, int8_t>,
    ops::Unsqueeze2Kernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
    unsqueeze2_grad,
    ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
    ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
// Kernel for the "unsqueeze" op: inserts size-1 dimensions into the input's
// shape at the positions given by the "axes" attribute. The tensor data is
// copied unchanged; only the shape metadata differs.
class UnsqueezeKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto &axes = context.Attr<std::vector<int>>("axes");
    auto *in = context.Input<framework::LoDTensor>("X");
    auto *out = context.Output<framework::LoDTensor>("Out");
    auto x_dims = in->dims();
    auto out_dims = GetOutputShape(axes, x_dims);

    // Allocate Out with X's element type, copy the raw data, then stamp the
    // unsqueezed shape onto the copy.
    out->mutable_data(context.GetPlace(), in->type());
    framework::TensorCopy(
        *in, context.GetPlace(),
        context.template device_context<platform::DeviceContext>(), out);
    out->Resize(out_dims);
  }

  // Computes the output shape produced by inserting a size-1 axis for every
  // entry of `unsqz_dims` into `in_dims`. Negative axes count from the end of
  // the (growing) output rank. Enforces that the resulting rank stays within
  // Eigen's 6-D limit and that every axis is within the current valid range.
  static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims,
                                        const framework::DDim &in_dims) {
    int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
    int cur_output_size = in_dims.size();
    // 0 marks "filled from the input later"; 1 marks an inserted axis.
    std::vector<int64_t> output_shape(output_size, 0);

    // Validity Check: rank range.
    PADDLE_ENFORCE_LE(output_size, 6,
                      "The output tensor's rank should be less than 6.");

    for (int axis : unsqz_dims) {
      int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
      // Validity Check: the axis bound. Bare enforces give unhelpful
      // diagnostics, so state which constraint failed.
      PADDLE_ENFORCE_GE(
          cur, 0, "The unsqueeze dims must be within range of current rank.");
      PADDLE_ENFORCE_LE(
          cur, cur_output_size,
          "The unsqueeze dims must be within range of current rank.");
      // Move old axis, and insert new axis
      for (int i = cur_output_size; i >= cur; --i) {
        if (output_shape[i] == 1) {
          // Move axis
          output_shape[i + 1] = 1;
          output_shape[i] = 0;
        }
      }
      output_shape[cur] = 1;
      // Add the output size.
      cur_output_size++;
    }

    // Make output shape: fill the remaining 0 slots with the input dims in
    // order.
    for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
      if (output_shape[out_idx] == 0) {
        output_shape[out_idx] = in_dims[in_idx++];
      }
    }

    return framework::make_ddim(output_shape);
  }
};
template <typename DeviceContext, typename T>
// Gradient kernel for "unsqueeze": since the forward pass is a pure reshape,
// the gradient is the incoming dOut copied verbatim and reshaped back to the
// input's original dimensions.
class UnsqueezeGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto x_dims = ctx.Input<framework::LoDTensor>("X")->dims();
    auto *out_grad =
        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    auto *x_grad =
        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));

    // Copy dOut into dX, then restore the pre-unsqueeze shape.
    x_grad->mutable_data(ctx.GetPlace(), out_grad->type());
    framework::TensorCopySync(*out_grad, ctx.GetPlace(), x_grad);
    x_grad->Resize(x_dims);
  }
};
template <typename DeviceContext, typename T>
// Kernel for "unsqueeze2": identical data path to UnsqueezeKernel (copy X,
// restamp the shape); the extra XShape output of the op is handled by shape
// inference, not here.
class Unsqueeze2Kernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto &axes = context.Attr<std::vector<int>>("axes");
    auto *input = context.Input<framework::LoDTensor>("X");
    auto *output = context.Output<framework::LoDTensor>("Out");

    // Reuse the shared shape computation from the unsqueeze kernel.
    auto unsqueezed_dims =
        UnsqueezeKernel<DeviceContext, T>::GetOutputShape(axes, input->dims());

    output->mutable_data(context.GetPlace(), input->type());
    framework::TensorCopy(
        *input, context.GetPlace(),
        context.template device_context<platform::DeviceContext>(), output);
    output->Resize(unsqueezed_dims);
  }
};
template <typename DeviceContext, typename T>
// Gradient kernel for "unsqueeze2": recovers X's shape from the XShape
// intermediate output (its leading dimension is a sentinel and is stripped),
// then reshapes the copied dOut accordingly.
class Unsqueeze2GradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *out_grad =
        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    auto *x_grad =
        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));

    // XShape = [0, d0, d1, ...]; drop the sentinel to get X's real dims.
    auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
    auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());

    x_grad->mutable_data(ctx.GetPlace(), out_grad->type());
    framework::TensorCopySync(*out_grad, ctx.GetPlace(), x_grad);
    x_grad->Resize(x_dims);
  }
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
# Correct: General.
class TestUnsqueezeOp(OpTest):
    """Base test for the unsqueeze2 op: two axes inserted in the middle."""

    def setUp(self):
        self.init_test_case()
        self.op_type = "unsqueeze2"
        x = np.random.random(self.ori_shape).astype("float32")
        self.inputs = {"X": x}
        self.init_attrs()
        # XShape only carries X's metadata between forward and backward; its
        # numeric content is never checked, so random data is fine.
        self.outputs = {
            "Out": x.reshape(self.new_shape),
            "XShape": np.random.random(self.ori_shape).astype("float32")
        }

    def test_check_output(self):
        # Skip XShape: it holds no meaningful values (see setUp).
        self.check_output(no_check_set=["XShape"])

    def test_check_grad(self):
        self.check_grad(["X"], "Out")

    def init_test_case(self):
        # Subclasses override this to exercise other axis combinations.
        self.ori_shape = (3, 5)
        self.axes = (1, 2)
        self.new_shape = (3, 1, 1, 5)

    def init_attrs(self):
        self.attrs = {"axes": self.axes}
# Correct: Single input index.
class TestUnsqueezeOp1(TestUnsqueezeOp):
    """Single negative axis: append one dimension at the end."""

    def init_test_case(self):
        self.ori_shape, self.axes = (3, 5), (-1, )
        self.new_shape = (3, 5, 1)
# Correct: Mixed input axis.
class TestUnsqueezeOp2(TestUnsqueezeOp):
    """Mixed positive and negative axes: insert at both ends."""

    def init_test_case(self):
        self.ori_shape, self.axes = (3, 5), (0, -1)
        self.new_shape = (1, 3, 5, 1)
# Correct: There is duplicated axis.
class TestUnsqueezeOp3(TestUnsqueezeOp):
    """Duplicated axis value: two size-1 dims end up adjacent at position 3."""

    def init_test_case(self):
        self.ori_shape, self.axes = (3, 2, 5), (0, 3, 3)
        self.new_shape = (1, 3, 2, 1, 1, 5)
# Correct: Reversed axes.
class TestUnsqueezeOp4(TestUnsqueezeOp):
    """Unsorted, repeated axes: insertion order must still be handled."""

    def init_test_case(self):
        self.ori_shape, self.axes = (3, 2, 5), (3, 1, 1)
        self.new_shape = (3, 1, 1, 2, 5, 1)
if __name__ == "__main__":
unittest.main()
...@@ -24,16 +24,13 @@ from op_test import OpTest ...@@ -24,16 +24,13 @@ from op_test import OpTest
class TestUnsqueezeOp(OpTest): class TestUnsqueezeOp(OpTest):
def setUp(self): def setUp(self):
self.init_test_case() self.init_test_case()
self.op_type = "unsqueeze2" self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")} self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
self.init_attrs() self.init_attrs()
self.outputs = { self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
"Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.ori_shape).astype("float32")
}
def test_check_output(self): def test_check_output(self):
self.check_output(no_check_set=["XShape"]) self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册