未验证 提交 1854814d 编写于 作者: Y yuyang18

Use reshape_op inside squeeze_op

* also convert tab to space
上级 3777f102
...@@ -265,6 +265,7 @@ op_library(recurrent_op DEPS executor) ...@@ -265,6 +265,7 @@ op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
op_library(cos_sim_op DEPS cos_sim_functor) op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor) op_library(parallel_do_op DEPS executor)
op_library(squeeze_op DEPS reshape_op)
if (WITH_GPU) if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv im2col) op_library(conv_op DEPS vol2col depthwise_conv im2col)
......
...@@ -12,33 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,33 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/squeeze_op.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::OpKernelType; class SqueezeOpInferShape : public framework::InferShapeBase {
using framework::Tensor;
class SqueezeOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; void operator()(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SqueezeOp should not be null."); "Input(X) of SqueezeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SqueezeOp should not be null."); "Output(Out) of SqueezeOp should not be null.");
const auto& x_dims = ctx->GetInputDim("X"); const auto &x_dims = ctx->GetInputDim("X");
// Check input tensor dims (<9). // Check input tensor dims (<9).
PADDLE_ENFORCE(x_dims.size() <= 9, PADDLE_ENFORCE(x_dims.size() <= 9,
"Invalid dimnesions, dynamic dimensions must have " "Invalid dimnesions, dynamic dimensions must have "
"between [1, 9] dimensions."); "between [1, 9] dimensions.");
const auto& axes = ctx->Attrs().Get<std::vector<int>>("axes"); const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
for (int a : axes) { for (int a : axes) {
PADDLE_ENFORCE_LT(a, x_dims.size(), PADDLE_ENFORCE_LT(a, x_dims.size(),
"The axis must be less than input tensor's rank."); "The axis must be less than input tensor's rank.");
...@@ -55,7 +50,7 @@ class SqueezeOp : public framework::OperatorWithKernel { ...@@ -55,7 +50,7 @@ class SqueezeOp : public framework::OperatorWithKernel {
} }
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims, static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
const framework::DDim& in_dims) { const framework::DDim &in_dims) {
int num_squeeze_dims = squeeze_dims.size(); int num_squeeze_dims = squeeze_dims.size();
int cnt_squeezed_dims = 0; int cnt_squeezed_dims = 0;
bool should_squeeze[9] = {false}; bool should_squeeze[9] = {false};
...@@ -100,6 +95,31 @@ class SqueezeOp : public framework::OperatorWithKernel { ...@@ -100,6 +95,31 @@ class SqueezeOp : public framework::OperatorWithKernel {
} }
}; };
class SqueezeOp : public framework::OperatorBase {
public:
SqueezeOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims);
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims);
attrs["inplace"] = Attr<bool>("inplace");
// Invoke Reshape Op
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {Input("X")}}, {"Shape", {}}},
{{"Out", {Output("Out")}}}, attrs);
reshape_op->Run(scope, place);
}
};
class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker { class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
...@@ -141,42 +161,48 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -141,42 +161,48 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
} }
}; };
class SqueezeGradOp : public framework::OperatorWithKernel { class SqueezeGradInferShape : public framework::InferShapeBase {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; void operator()(framework::InferShapeContext *context) const override {
context->SetOutputDim(framework::GradVarName("X"),
void InferShape(framework::InferShapeContext* ctx) const override { context->GetInputDim("X"));
PADDLE_ENFORCE(ctx->HasInput("X"), context->ShareLoD("X", framework::GradVarName("X"));
"Input(X) of SqueezeGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Output(Out@GRAD) of SqueezeGradOp should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
} }
};
protected: class SqueezeGradOp : public framework::OperatorBase {
framework::OpKernelType GetExpectedKernelType( public:
const framework::ExecutionContext& ctx) const override { SqueezeGradOp(const std::string &type,
return framework::OpKernelType( const framework::VariableNameMap &inputs,
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), const framework::VariableNameMap &outputs,
ctx.device_context()); const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(x_dims);
attrs["inplace"] = Attr<bool>("inplace");
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
attrs);
reshape_op->Run(scope, place);
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
// Tell linker to use reshape op
USE_OP(reshape);
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker, REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker,
ops::SqueezeOpInferShape,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp); REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp, ops::SqueezeGradInferShape);
REGISTER_OP_CPU_KERNEL(
squeeze, ops::SqueezeKernel<paddle::platform::CPUDeviceContext, float>,
ops::SqueezeKernel<paddle::platform::CPUDeviceContext, double>,
ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int>,
ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
squeeze_grad,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/squeeze_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
squeeze, ops::SqueezeKernel<paddle::platform::CUDADeviceContext, float>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, double>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
squeeze_grad,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class SqueezeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *out = ctx.Output<framework::LoDTensor>("Out");
auto *in = ctx.Input<framework::LoDTensor>("X");
framework::DDim out_dims = out->dims();
bool inplace = ctx.Attr<bool>("inplace");
out->Resize(out_dims);
if (!inplace) {
out->mutable_data<T>(ctx.GetPlace());
framework::TensorCopySync(*in, ctx.GetPlace(), out);
out->Resize(out_dims);
} else {
out->ShareDataWith(*in);
out->Resize(out_dims);
}
}
};
template <typename DeviceContext, typename T>
class SqueezeGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
d_x->mutable_data<T>(ctx.GetPlace());
bool inplace = ctx.Attr<bool>("inplace");
auto in_dims = d_x->dims();
if (!inplace) {
framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
ctx.device_context().Wait();
d_x->Resize(in_dims);
} else {
d_x->ShareDataWith(*d_out);
d_x->Resize(in_dims);
}
}
};
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册