提交 118bb897 编写于 作者: Z zhongpu 提交者: Jiabin Yang

add kernel for flatten_op, test=develop (#19472)

* add kernel for flatten_op, test=develop

* add kernel for flatten_op, test=develop

* fix the license and remove redundant code, test=develop
上级 0a46d345
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/flatten_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
......@@ -20,18 +24,21 @@ namespace operators {
using Tensor = framework::Tensor;
class FlattenOpInferShape : public framework::InferShapeBase {
class FlattenOp : public framework::OperatorWithKernel {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input (X) of Flatten op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output (Output) of Flatten op should not be null.");
const auto &axis = ctx->Attrs().Get<int>("axis");
const auto &in_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(axis >= 0, "The axis should be greater than or equal to 0.");
PADDLE_ENFORCE(
axis <= in_dims.size(),
PADDLE_ENFORCE_GE(axis, 0,
"The axis should be greater than or equal to 0.");
PADDLE_ENFORCE_LE(
axis, in_dims.size(),
"The axis should be less than or equal to input tensor's rank.");
const auto &out_dims = GetOutputShape(axis, in_dims);
......@@ -58,28 +65,12 @@ class FlattenOpInferShape : public framework::InferShapeBase {
out_shape[1] = inner;
return out_shape;
}
};
class FlattenOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &axis = Attr<int>("axis");
auto in_dims =
scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
const auto &out_dims = FlattenOpInferShape::GetOutputShape(axis, in_dims);
framework::AttributeMap attrs;
attrs["shape"] = out_dims;
attrs["inplace"] = false;
// Invoke Reshape Op
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {Input("X")}}, {"Shape", {}}},
{{"Out", {Output("Out")}}}, attrs);
reshape_op->Run(scope, place);
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
};
......@@ -126,34 +117,21 @@ Case 2:
}
};
class FlattenGradInferShape : public framework::InferShapeBase {
class FlattenGradOp : public framework::OperatorWithKernel {
public:
void operator()(framework::InferShapeContext *context) const override {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
context->SetOutputDim(framework::GradVarName("X"),
context->GetInputDim("X"));
context->ShareLoD("X", framework::GradVarName("X"));
}
};
class FlattenGradOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
auto in_dims =
scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(in_dims);
attrs["inplace"] = false;
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
attrs);
reshape_op->Run(scope, place);
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
};
......@@ -162,13 +140,33 @@ class FlattenGradOp : public framework::OperatorBase {
// flatten_grad, in this way, the framework can reuse the memory of X
// immediately the flatten2_op is finished.
// Considering compatibility issues, we could not fix flatten2_op
class Flatten2OpInferShape : public FlattenOpInferShape {
class Flatten2Op : public framework::OperatorWithKernel {
public:
void operator()(framework::InferShapeContext *ctx) const override {
FlattenOpInferShape::operator()(ctx);
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
"Output (XShape) of Flatten op should not be null.");
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input (X) of Flatten op should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output (Output) of Flatten op should not be null.");
const auto &axis = ctx->Attrs().Get<int>("axis");
const auto &in_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(axis, 0,
"The axis should be greater than or equal to 0.");
PADDLE_ENFORCE_LE(
axis, in_dims.size(),
"The axis should be less than or equal to input tensor's rank.");
const auto &out_dims = FlattenOp::GetOutputShape(axis, in_dims);
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
if (in_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx->ShareLoD("X", "Out");
}
PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true,
"Output (XShape) of Flatten op should not be null.");
std::vector<int64_t> xshape_dims(in_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < in_dims.size(); ++i) {
......@@ -179,29 +177,6 @@ class Flatten2OpInferShape : public FlattenOpInferShape {
}
};
class Flatten2Op : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &axis = Attr<int>("axis");
auto in_dims =
scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
const auto &out_dims = FlattenOpInferShape::GetOutputShape(axis, in_dims);
framework::AttributeMap attrs;
attrs["shape"] = out_dims;
attrs["inplace"] = false;
// Invoke Reshape Op
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape2", {{"X", {Input("X")}}, {"Shape", {}}},
{{"Out", {Output("Out")}}, {"XShape", {Output("XShape")}}}, attrs);
reshape_op->Run(scope, place);
}
};
class Flatten2OpMaker : public FlattenOpMaker {
public:
void Make() override {
......@@ -228,43 +203,27 @@ class Flatten2GradOpMaker : public framework::SingleGradOpDescMaker {
}
};
class Flatten2GradInferShape : public framework::InferShapeBase {
class Flatten2GradOp : public framework::OperatorWithKernel {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("XShape"),
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE_EQ(context->HasInput("XShape"), true,
"Input(XShape) shouldn't be null.");
PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
PADDLE_ENFORCE_EQ(context->HasInput(framework::GradVarName("Out")), true,
"Input(Out@GRAD) shouldn't be null.");
auto xshape_dims = context->GetInputDim("XShape");
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
context->SetOutputDim(framework::GradVarName("X"), x_dims);
context->ShareLoD("XShape", framework::GradVarName("X"));
}
};
class Flatten2GradOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
auto xshape_name = Input("XShape");
auto xshape_dims =
scope.FindVar(xshape_name)->Get<framework::LoDTensor>().dims();
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(x_dims);
attrs["inplace"] = false;
auto reshape_grad_op = framework::OpRegistry::CreateOp(
"reshape2_grad",
{{"Out@GRAD", {dout_name}}, {"Shape", {}}, {"XShape", {xshape_name}}},
{{"X@GRAD", {dx_name}}}, attrs);
reshape_grad_op->Run(scope, place);
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
ctx.device_context());
}
};
......@@ -276,18 +235,41 @@ DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceinToOut,
} // namespace operators
} // namespace paddle
USE_OP(reshape);
namespace ops = paddle::operators;
REGISTER_OPERATOR(flatten, ops::FlattenOp, ops::FlattenOpMaker,
ops::FlattenOpInferShape,
paddle::framework::DefaultGradOpDescMaker<true>,
ops::FlattenOpInplaceInToOut);
REGISTER_OPERATOR(flatten_grad, ops::FlattenGradOp, ops::FlattenGradInferShape,
REGISTER_OPERATOR(flatten_grad, ops::FlattenGradOp,
ops::FlattenGradInplaceinToOut);
REGISTER_OPERATOR(flatten2, ops::Flatten2Op, ops::Flatten2OpMaker,
ops::Flatten2OpInferShape, ops::Flatten2GradOpMaker,
ops::FlattenOpInplaceInToOut);
ops::Flatten2GradOpMaker, ops::FlattenOpInplaceInToOut);
REGISTER_OPERATOR(flatten2_grad, ops::Flatten2GradOp,
ops::Flatten2GradInferShape, ops::FlattenGradInplaceinToOut);
ops::FlattenGradInplaceinToOut);
REGISTER_OP_CPU_KERNEL(
flatten, ops::FlattenKernel<paddle::platform::CPUDeviceContext, float>,
ops::FlattenKernel<paddle::platform::CPUDeviceContext, double>,
ops::FlattenKernel<paddle::platform::CPUDeviceContext, int>,
ops::FlattenKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::FlattenKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
flatten_grad,
ops::FlattenGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::FlattenGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::FlattenGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::FlattenGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::FlattenGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
flatten2, ops::Flatten2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::Flatten2Kernel<paddle::platform::CPUDeviceContext, double>,
ops::Flatten2Kernel<paddle::platform::CPUDeviceContext, int>,
ops::Flatten2Kernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Flatten2Kernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
flatten2_grad,
ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, double>,
ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/flatten_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
flatten, ops::FlattenKernel<paddle::platform::CUDADeviceContext, float>,
ops::FlattenKernel<paddle::platform::CUDADeviceContext, double>,
ops::FlattenKernel<paddle::platform::CUDADeviceContext, int>,
ops::FlattenKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::FlattenKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
flatten_grad,
ops::FlattenGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::FlattenGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::FlattenGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::FlattenGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::FlattenGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
flatten2, ops::Flatten2Kernel<paddle::platform::CUDADeviceContext, float>,
ops::Flatten2Kernel<paddle::platform::CUDADeviceContext, double>,
ops::Flatten2Kernel<paddle::platform::CUDADeviceContext, int>,
ops::Flatten2Kernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Flatten2Kernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
flatten2_grad,
ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, float>,
ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, double>,
ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class FlattenKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *in = context.Input<framework::LoDTensor>("X");
auto *out = context.Output<framework::LoDTensor>("Out");
auto &axes = context.Attr<int>("axis");
auto x_dims = in->dims();
auto out_dims = framework::make_ddim(GetOutputShape(axes, x_dims));
out->mutable_data(context.GetPlace(), in->type());
framework::TensorCopy(
*in, context.GetPlace(),
context.template device_context<platform::DeviceContext>(), out);
out->Resize(out_dims);
}
static std::vector<int32_t> GetOutputShape(const int axis,
const framework::DDim &in_dims) {
int64_t outer = 1, inner = 1;
for (int i = 0; i < in_dims.size(); ++i) {
if (i < axis) {
outer *= in_dims[i];
} else {
inner *= in_dims[i];
}
}
std::vector<int32_t> out_shape(2);
out_shape[0] = outer;
out_shape[1] = inner;
return out_shape;
}
};
template <typename DeviceContext, typename T>
class FlattenGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto *d_out =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto in_dims = ctx.Input<framework::LoDTensor>("X")->dims();
d_x->mutable_data(ctx.GetPlace(), d_out->type());
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
d_x->Resize(in_dims);
}
};
template <typename DeviceContext, typename T>
class Flatten2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto &axes = context.Attr<int>("axis");
auto *in = context.Input<framework::LoDTensor>("X");
auto x_dims = in->dims();
auto *out = context.Output<framework::LoDTensor>("Out");
auto out_dims = framework::make_ddim(
FlattenKernel<DeviceContext, T>::GetOutputShape(axes, x_dims));
out->mutable_data(context.GetPlace(), in->type());
framework::TensorCopy(
*in, context.GetPlace(),
context.template device_context<platform::DeviceContext>(), out);
out->Resize(out_dims);
}
};
template <typename DeviceContext, typename T>
class Flatten2GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto *d_out =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
d_x->mutable_data(ctx.GetPlace(), d_out->type());
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
d_x->Resize(x_dims);
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
class TestFlattenOp(OpTest):
def setUp(self):
self.op_type = "flatten2"
self.init_test_case()
self.inputs = {"X": np.random.random(self.in_shape).astype("float32")}
self.init_attrs()
self.outputs = {
"Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.in_shape).astype("float32")
}
def test_check_output(self):
self.check_output(no_check_set=["XShape"])
def test_check_grad(self):
self.check_grad(["X"], "Out")
def init_test_case(self):
self.in_shape = (3, 2, 2, 5)
self.axis = 1
self.new_shape = (3, 20)
def init_attrs(self):
self.attrs = {"axis": self.axis}
class TestFlattenOp(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 2, 3)
self.axis = 0
self.new_shape = (1, 36)
class TestFlattenOpWithDefaultAxis(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 2, 3)
self.new_shape = (3, 12)
def init_attrs(self):
self.attrs = {}
class TestFlattenOpSixDims(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 3, 2, 4, 4)
self.axis = 4
self.new_shape = (36, 16)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -22,17 +22,14 @@ from op_test import OpTest
class TestFlattenOp(OpTest):
def setUp(self):
self.op_type = "flatten2"
self.op_type = "flatten"
self.init_test_case()
self.inputs = {"X": np.random.random(self.in_shape).astype("float32")}
self.init_attrs()
self.outputs = {
"Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.in_shape).astype("float32")
}
self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
def test_check_output(self):
self.check_output(no_check_set=["XShape"])
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册