Unverified  Commit 1aa67778 authored by Aurelius84, committed by GitHub

[Phi] Migrate unfold_op into phi (#39778)

* [Phi] Migrate unfold_op into phi

* fix im2col CPUContext template instantiation

* fix unfold_op.h header include problem

* fix unittest

* fix PT->PD
Parent 60fc555e
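At a glance, the migration swaps the fluid OpKernel class for a phi free-function kernel. A minimal side-by-side sketch, condensed from the signatures that appear in the diff below (simplified, not verbatim):

    // Before (fluid): a kernel class keyed on DeviceContext; inputs and
    // attributes are pulled from the ExecutionContext at runtime.
    template <typename DeviceContext, typename T>
    class UnfoldOpKernel : public framework::OpKernel<T> {
     public:
      void Compute(const framework::ExecutionContext& ctx) const override;
    };

    // After (phi): a free function template with explicit, typed parameters,
    // registered per backend via PD_REGISTER_KERNEL.
    template <typename T, typename Context>
    void UnfoldKernel(const Context& ctx,
                      const DenseTensor& x,
                      const std::vector<int>& kernel_sizes,
                      const std::vector<int>& strides,
                      const std::vector<int>& paddings,
                      const std::vector<int>& dilations,
                      DenseTensor* out);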
@@ -442,7 +442,9 @@ void BuildDygraphPtenKernelContext(
                                      vector_int_attr.end());
           kernel_ctx->EmplaceBackAttr(vector_int64_attr);
         }
-        // TODO(YuanRisheng) Need support vector<int64_t> attr
+      } else if (attr_defs[i].type_index ==
+                 std::type_index(typeid(std::vector<int>))) {
+        kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<int>, attr));
       } else {
         PADDLE_THROW(platform::errors::Unimplemented(
             "Unsupported cast op attribute `%s` when construct "
......
@@ -13,7 +13,6 @@
 * limitations under the License. */

 #include "paddle/fluid/operators/fold_op.h"
-#include "paddle/fluid/operators/unfold_op.h"

 namespace paddle {
 namespace operators {
......
@@ -22,6 +22,10 @@ class CPUDeviceContext;
 }  // namespace platform
 }  // namespace paddle

+namespace phi {
+class CPUContext;
+}  // namespace phi
+
 namespace paddle {
 namespace operators {
 namespace math {
@@ -31,12 +35,12 @@ namespace math {
  * col =
  * [input_channels, filter_height, filter_width, output_height, output_width]
  */
-template <class T>
-class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
-                    platform::CPUDeviceContext, T> {
+template <class T, typename DeviceContext>
+class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, DeviceContext,
+                    T> {
  public:
-  void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& im, const std::vector<int>& dilation,
+  void operator()(const DeviceContext& context, const framework::Tensor& im,
+                  const std::vector<int>& dilation,
                   const std::vector<int>& stride,
                   const std::vector<int>& padding, framework::Tensor* col,
                   const DataLayout data_layout) {
@@ -73,12 +77,11 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
  * col =
  * [input_channels, filter_height, filter_width, output_height, output_width]
  */
-template <class T>
-class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
-                    platform::CPUDeviceContext, T> {
+template <class T, typename DeviceContext>
+class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, DeviceContext,
+                    T> {
  public:
-  void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& col,
+  void operator()(const DeviceContext& context, const framework::Tensor& col,
                   const std::vector<int>& dilation,
                   const std::vector<int>& stride,
                   const std::vector<int>& padding, framework::Tensor* im,
@@ -155,22 +158,30 @@ template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                              platform::CPUDeviceContext, float>;
 template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                              platform::CPUDeviceContext, double>;
+template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
+                             phi::CPUContext, float>;
+template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
+                             phi::CPUContext, double>;
 template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                              platform::CPUDeviceContext, float>;
 template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                              platform::CPUDeviceContext, double>;
+template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
+                             phi::CPUContext, float>;
+template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
+                             phi::CPUContext, double>;
 /*
  * im = [input_channels, input_height, input_width]
  * col =
  * [output_height, output_width, input_channels, filter_height, filter_width]
  */
-template <class T>
-class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
-                    platform::CPUDeviceContext, T> {
+template <class T, typename DeviceContext>
+class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, DeviceContext,
+                    T> {
  public:
-  void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& im, const std::vector<int>& dilation,
+  void operator()(const DeviceContext& context, const framework::Tensor& im,
+                  const std::vector<int>& dilation,
                   const std::vector<int>& stride,
                   const std::vector<int>& padding, framework::Tensor* col,
                   const DataLayout data_layout) {
@@ -235,12 +246,11 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
  * col =
  * [output_height, output_width, input_channels, filter_height, filter_width]
  */
-template <class T>
-class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
-                    platform::CPUDeviceContext, T> {
+template <class T, typename DeviceContext>
+class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, DeviceContext,
+                    T> {
  public:
-  void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& col,
+  void operator()(const DeviceContext& context, const framework::Tensor& col,
                   const std::vector<int>& dilation,
                   const std::vector<int>& stride,
                   const std::vector<int>& padding, framework::Tensor* im,
@@ -316,11 +326,18 @@ template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                              platform::CPUDeviceContext, float>;
 template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                              platform::CPUDeviceContext, double>;
+template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
+                             phi::CPUContext, float>;
+template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
+                             phi::CPUContext, double>;
 template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                              platform::CPUDeviceContext, float>;
 template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                              platform::CPUDeviceContext, double>;
+template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
+                             phi::CPUContext, float>;
+template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
+                             phi::CPUContext, double>;
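The instantiation lines added above follow the usual explicit-instantiation pattern: the templated definitions stay in this .cc file, and one `template class` line per (context, type) pair emits the symbols that other translation units link against. A self-contained toy illustration of the pattern (hypothetical types, not Paddle code):

    #include <vector>

    struct ContextA {};  // stands in for platform::CPUDeviceContext
    struct ContextB {};  // stands in for phi::CPUContext

    template <class T, typename DeviceContext>
    struct Functor {
      void operator()(const DeviceContext&, const std::vector<T>& in,
                      std::vector<T>* out) const {
        *out = in;  // real im2col rearranges patches; identity keeps this short
      }
    };

    // One explicit instantiation per pair that external callers may link to:
    template struct Functor<float, ContextA>;
    template struct Functor<float, ContextB>;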
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
@@ -12,7 +12,9 @@
 * See the License for the specific language governing permissions and
 * limitations under the License. */

-#include "paddle/fluid/operators/unfold_op.h"
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/infermeta/unary.h"

 namespace paddle {
 namespace operators {
@@ -60,126 +62,6 @@ feature map, a series of such columns will be formed.
 class UnfoldOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of UnfoldOp should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Y"), true,
platform::errors::NotFound("Output(Y) of UnfoldOp should not be null"));
auto in_dims = ctx->GetInputDim("X");
std::vector<int> kernel_sizes =
ctx->Attrs().Get<std::vector<int>>("kernel_sizes");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::vector<int> dilations =
ctx->Attrs().Get<std::vector<int>>("dilations");
// Only [N, C, H, W] input supported now
PADDLE_ENFORCE_EQ(
in_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be 4-D tensor of format [N, C, H, W], but get %u",
in_dims.size()));
PADDLE_ENFORCE_EQ(
in_dims.size() - kernel_sizes.size(), 2U,
platform::errors::InvalidArgument(
"The dims of X should be larger than that of kernel_sizes "
"by a number of 2, due to the batch size and input channel dim. "
"But recieved dims(X:%u) - dims(kernel_sizes:%u) != 2",
in_dims.size(), kernel_sizes.size()));
PADDLE_ENFORCE_EQ(
strides.size(), kernel_sizes.size(),
platform::errors::InvalidArgument(
"The dims of strides should be the same with that of kernel_sizes. "
"But recieved dims(strides: %u) != dims(kernel_sizes: %u).",
strides.size(), kernel_sizes.size()));
PADDLE_ENFORCE_EQ(
paddings.size(), 2 * strides.size(),
platform::errors::InvalidArgument(
"The dims of paddings should be 2 times of that of strides. "
"But recieved dims(paddings: %u) != 2*dims(strides: %u).",
paddings.size(), strides.size()));
PADDLE_ENFORCE_EQ(
strides.size(), dilations.size(),
platform::errors::InvalidArgument(
"The dims of strides should be the same with that of dilations. "
"But recieved dims(strides: %u) != dims(dilations: %u).",
strides.size(), dilations.size()));
// check kernel_sizes
PADDLE_ENFORCE_GT(kernel_sizes[0], 0,
platform::errors::InvalidArgument(
"The `kernel_sizes` should be greater than zero, "
"but recieved kernel_height: %d kernel_width: %d.",
kernel_sizes[0], kernel_sizes[1]));
PADDLE_ENFORCE_GT(kernel_sizes[1], 0,
platform::errors::InvalidArgument(
"The `kernel_sizes` should be greater than zero, "
"but recieved kernel_height: %d kernel_width: %d.",
kernel_sizes[0], kernel_sizes[1]));
// check strides
PADDLE_ENFORCE_GT(strides[0], 0,
platform::errors::InvalidArgument(
"The `strides` should be greater than zero, "
"but recieved strides_height: %d strides_width: %d.",
strides[0], strides[1]));
PADDLE_ENFORCE_GT(strides[1], 0,
platform::errors::InvalidArgument(
"The `strides` should be greater than zero, "
"but recieved strides_height: %d strides_width: %d.",
strides[0], strides[1]));
// check dilations
PADDLE_ENFORCE_GT(
dilations[0], 0,
platform::errors::InvalidArgument(
"The `dilations` should be greater than zero, "
"but recieved dilations_height: %d dilations_width: %d.",
dilations[0], dilations[1]));
PADDLE_ENFORCE_GT(
dilations[1], 0,
platform::errors::InvalidArgument(
"The `dilations` should be greater than zero, "
"but recieved dilations_height: %d dilations_width: %d.",
dilations[0], dilations[1]));
std::vector<int> out_dims;
out_dims.push_back(in_dims[0]);
int output_channels = in_dims[1] * kernel_sizes[0] * kernel_sizes[1];
out_dims.push_back(output_channels);
int output_height =
CalcOutputSize(in_dims[2], kernel_sizes[0], dilations[0], paddings[0],
paddings[2], strides[0]);
int output_width = CalcOutputSize(in_dims[3], kernel_sizes[1], dilations[1],
paddings[1], paddings[3], strides[1]);
if (ctx->IsRuntime()) {
// only check output height and width in runtime
PADDLE_ENFORCE_GT(
output_height, 0,
platform::errors::InvalidArgument(
"The sliding blocks calculated from input spatial size "
"(%d, %d), kernel_sizes (%d, %d), strides (%d, %d), "
"dilations (%d, %d), is (%d, %d), which should be a "
"positive integer.",
in_dims[2], in_dims[3], kernel_sizes[0], kernel_sizes[1],
strides[0], strides[1], dilations[0], dilations[1], output_height,
output_width));
PADDLE_ENFORCE_GT(
output_width, 0,
platform::errors::InvalidArgument(
"The sliding blocks calculated from input spatial size "
"(%d, %d), kernel_sizes (%d, %d), strides (%d, %d), "
"dilations (%d, %d), is (%d, %d), which should be a "
"positive integer.",
in_dims[2], in_dims[3], kernel_sizes[0], kernel_sizes[1],
strides[0], strides[1], dilations[0], dilations[1], output_height,
output_width));
}
int output_col_length = output_height * output_width;
out_dims.push_back(output_col_length);
ctx->SetOutputDim("Y", phi::make_ddim(out_dims));
}
  protected:
   framework::OpKernelType GetExpectedKernelType(
@@ -237,16 +119,11 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnfoldGradOpNoNeedBufferVarsInferer, "X");
 }  // namespace paddle

 namespace ops = paddle::operators;
+DELCARE_INFER_SHAPE_FUNCTOR(unfold, UnfoldInferShapeFunctor,
+                            PT_INFER_META(phi::UnfoldInferMeta));
 REGISTER_OPERATOR(unfold, ops::UnfoldOp, ops::UnfoldOpMaker,
                   ops::UnfoldGradMaker<paddle::framework::OpDesc>,
-                  ops::UnfoldGradMaker<paddle::imperative::OpBase>);
+                  ops::UnfoldGradMaker<paddle::imperative::OpBase>,
+                  UnfoldInferShapeFunctor);
 REGISTER_OPERATOR(unfold_grad, ops::UnfoldGradOp,
                   ops::UnfoldGradOpNoNeedBufferVarsInferer);
-REGISTER_OP_CPU_KERNEL(
-    unfold, ops::UnfoldOpKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::UnfoldOpKernel<paddle::platform::CPUDeviceContext, double>);
-REGISTER_OP_CPU_KERNEL(
-    unfold_grad,
-    ops::UnfoldGradOpKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::UnfoldGradOpKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
inline int CalcOutputSize(int input_size, int filter_size, int dilation,
int padding1, int padding2, int stride) {
const int dkernel = dilation * (filter_size - 1) + 1;
int output_size = (input_size + padding1 + padding2 - dkernel) / stride + 1;
return output_size;
}
template <typename DeviceContext, typename T>
class UnfoldOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* input = ctx.Input<Tensor>("X");
const int batch_size = static_cast<int>(input->dims()[0]);
Tensor* output = ctx.Output<Tensor>("Y");
output->mutable_data<T>(ctx.GetPlace());
std::vector<int> kernel_sizes = ctx.Attr<std::vector<int>>("kernel_sizes");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto input_dims = input->dims();
int output_height =
CalcOutputSize(input_dims[2], kernel_sizes[0], dilations[0],
paddings[0], paddings[2], strides[0]);
int output_width =
CalcOutputSize(input_dims[3], kernel_sizes[1], dilations[1],
paddings[1], paddings[3], strides[1]);
framework::DDim input_shape({input_dims[1], input_dims[2], input_dims[3]});
framework::DDim output_matrix_shape({input_dims[1], kernel_sizes[0],
kernel_sizes[1], output_height,
output_width});
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
im2col(dev_ctx, in_batch, dilations, strides, paddings, &out_batch);
}
}
};
template <typename DeviceContext, typename T>
class UnfoldGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* output_grad = ctx.Input<Tensor>(framework::GradVarName("Y"));
Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
input_grad->mutable_data<T>(ctx.GetPlace());
if ((!output_grad) || (!input_grad)) return;
std::vector<int> kernel_sizes = ctx.Attr<std::vector<int>>("kernel_sizes");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(input_grad->dims()[0]);
auto input_dims = input_grad->dims();
int output_height =
CalcOutputSize(input_dims[2], kernel_sizes[0], dilations[0],
paddings[0], paddings[2], strides[0]);
int output_width =
CalcOutputSize(input_dims[3], kernel_sizes[1], dilations[1],
paddings[1], paddings[3], strides[1]);
framework::DDim input_shape({input_dims[1], input_dims[2], input_dims[3]});
framework::DDim output_matrix_shape({input_dims[1], kernel_sizes[0],
kernel_sizes[1], output_height,
output_width});
math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
phi::funcs::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, input_grad, static_cast<T>(0));
for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape);
col2im(dev_ctx, out_grad_batch, dilations, strides, paddings,
&in_grad_batch);
}
}
};
} // namespace operators
} // namespace paddle
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/kernels/funcs/unfold_functor.h"

 namespace phi {
@@ -537,6 +538,164 @@ void TraceInferMeta(
   out->set_dims(phi::make_ddim(sizes));
 }
void UnfoldInferMeta(const MetaTensor& x,
const std::vector<int>& kernel_sizes,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
MetaTensor* out,
MetaConfig config) {
auto in_dims = x.dims();
// Only [N, C, H, W] input supported now
PADDLE_ENFORCE_EQ(
in_dims.size(),
4,
phi::errors::InvalidArgument(
"Input should be 4-D tensor of format [N, C, H, W], but get %u",
in_dims.size()));
PADDLE_ENFORCE_EQ(
in_dims.size() - kernel_sizes.size(),
2U,
phi::errors::InvalidArgument(
"The dims of X should be larger than that of kernel_sizes "
"by a number of 2, due to the batch size and input channel dim. "
"But recieved dims(X:%u) - dims(kernel_sizes:%u) != 2",
in_dims.size(),
kernel_sizes.size()));
PADDLE_ENFORCE_EQ(
strides.size(),
kernel_sizes.size(),
phi::errors::InvalidArgument(
"The dims of strides should be the same with that of kernel_sizes. "
"But recieved dims(strides: %u) != dims(kernel_sizes: %u).",
strides.size(),
kernel_sizes.size()));
PADDLE_ENFORCE_EQ(
paddings.size(),
2 * strides.size(),
phi::errors::InvalidArgument(
"The dims of paddings should be 2 times of that of strides. "
"But recieved dims(paddings: %u) != 2*dims(strides: %u).",
paddings.size(),
strides.size()));
PADDLE_ENFORCE_EQ(
strides.size(),
dilations.size(),
phi::errors::InvalidArgument(
"The dims of strides should be the same with that of dilations. "
"But recieved dims(strides: %u) != dims(dilations: %u).",
strides.size(),
dilations.size()));
// check kernel_sizes
PADDLE_ENFORCE_GT(kernel_sizes[0],
0,
phi::errors::InvalidArgument(
"The `kernel_sizes` should be greater than zero, "
"but recieved kernel_height: %d kernel_width: %d.",
kernel_sizes[0],
kernel_sizes[1]));
PADDLE_ENFORCE_GT(kernel_sizes[1],
0,
phi::errors::InvalidArgument(
"The `kernel_sizes` should be greater than zero, "
"but recieved kernel_height: %d kernel_width: %d.",
kernel_sizes[0],
kernel_sizes[1]));
// check strides
PADDLE_ENFORCE_GT(strides[0],
0,
phi::errors::InvalidArgument(
"The `strides` should be greater than zero, "
"but recieved strides_height: %d strides_width: %d.",
strides[0],
strides[1]));
PADDLE_ENFORCE_GT(strides[1],
0,
phi::errors::InvalidArgument(
"The `strides` should be greater than zero, "
"but recieved strides_height: %d strides_width: %d.",
strides[0],
strides[1]));
// check dilations
PADDLE_ENFORCE_GT(
dilations[0],
0,
phi::errors::InvalidArgument(
"The `dilations` should be greater than zero, "
"but recieved dilations_height: %d dilations_width: %d.",
dilations[0],
dilations[1]));
PADDLE_ENFORCE_GT(
dilations[1],
0,
phi::errors::InvalidArgument(
"The `dilations` should be greater than zero, "
"but recieved dilations_height: %d dilations_width: %d.",
dilations[0],
dilations[1]));
std::vector<int> out_dims;
out_dims.push_back(in_dims[0]);
int output_channels = in_dims[1] * kernel_sizes[0] * kernel_sizes[1];
out_dims.push_back(output_channels);
int output_height = phi::funcs::CalcOutputSize(in_dims[2],
kernel_sizes[0],
dilations[0],
paddings[0],
paddings[2],
strides[0]);
int output_width = phi::funcs::CalcOutputSize(in_dims[3],
kernel_sizes[1],
dilations[1],
paddings[1],
paddings[3],
strides[1]);
if (config.is_runtime) {
// only check output height and width in runtime
PADDLE_ENFORCE_GT(
output_height,
0,
phi::errors::InvalidArgument(
"The sliding blocks calculated from input spatial size "
"(%d, %d), kernel_sizes (%d, %d), strides (%d, %d), "
"dilations (%d, %d), is (%d, %d), which should be a "
"positive integer.",
in_dims[2],
in_dims[3],
kernel_sizes[0],
kernel_sizes[1],
strides[0],
strides[1],
dilations[0],
dilations[1],
output_height,
output_width));
PADDLE_ENFORCE_GT(
output_width,
0,
phi::errors::InvalidArgument(
"The sliding blocks calculated from input spatial size "
"(%d, %d), kernel_sizes (%d, %d), strides (%d, %d), "
"dilations (%d, %d), is (%d, %d), which should be a "
"positive integer.",
in_dims[2],
in_dims[3],
kernel_sizes[0],
kernel_sizes[1],
strides[0],
strides[1],
dilations[0],
dilations[1],
output_height,
output_width));
}
int output_col_length = output_height * output_width;
out_dims.push_back(output_col_length);
out->set_dims(phi::make_ddim(out_dims));
}
 }  // namespace phi

 PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);
......
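As a concrete check of the shape arithmetic in UnfoldInferMeta, a standalone sketch with illustrative values (x = [2, 3, 8, 8], 2x2 kernel, stride 1, zero padding, dilation 1; the helper mirrors phi::funcs::CalcOutputSize):

    #include <cstdio>

    int CalcOutputSize(int input_size, int filter_size, int dilation,
                       int padding1, int padding2, int stride) {
      const int dkernel = dilation * (filter_size - 1) + 1;
      return (input_size + padding1 + padding2 - dkernel) / stride + 1;
    }

    int main() {
      int oh = CalcOutputSize(8, 2, 1, 0, 0, 1);  // (8 + 0 + 0 - 2) / 1 + 1 = 7
      int ow = CalcOutputSize(8, 2, 1, 0, 0, 1);  // 7
      std::printf("out = [%d, %d, %d]\n", 2, 3 * 2 * 2, oh * ow);
      // prints: out = [2, 12, 49], i.e. [N, C * kh * kw, oh * ow]
      return 0;
    }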
@@ -93,4 +93,11 @@ void SplitInferMeta(const MetaTensor& x_meta,
 void TraceInferMeta(
     const MetaTensor& x, int offset, int axis1, int axis2, MetaTensor* out);
void UnfoldInferMeta(const MetaTensor& x,
const std::vector<int>& kernel_sizes,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
MetaTensor* out,
MetaConfig config = MetaConfig());
 }  // namespace phi
@@ -10,7 +10,7 @@ add_subdirectory(funcs)
 set_property(GLOBAL PROPERTY PTEN_KERNELS "")

 set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils)
-set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function)
+set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col)
 # remove this dep after removing fluid deps on tensor creation
 set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} pten_api_utils)
 set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/unfold_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/unfold_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
unfold_grad, CPU, ALL_LAYOUT, phi::UnfoldGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/unfold_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/unfold_kernel_impl.h"
PD_REGISTER_KERNEL(unfold, CPU, ALL_LAYOUT, phi::UnfoldKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace phi {
namespace funcs {
//////// CalcOutputSize Functor ///////
inline int CalcOutputSize(int input_size,
int filter_size,
int dilation,
int padding1,
int padding2,
int stride) {
const int dkernel = dilation * (filter_size - 1) + 1;
int output_size = (input_size + padding1 + padding2 - dkernel) / stride + 1;
return output_size;
}
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/unfold_grad_kernel_impl.h"
#include "paddle/phi/kernels/unfold_grad_kernel.h"
PD_REGISTER_KERNEL(
unfold_grad, GPU, ALL_LAYOUT, phi::UnfoldGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/unfold_kernel_impl.h"
#include "paddle/phi/kernels/unfold_kernel.h"
PD_REGISTER_KERNEL(unfold, GPU, ALL_LAYOUT, phi::UnfoldKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h"
namespace phi {
template <typename T, typename Context>
void UnfoldGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const std::vector<int>& kernel_sizes,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
DenseTensor* x_grad) {
ctx.template Alloc<T>(x_grad);
if (!x_grad) return;
auto x_dims = x_grad->dims();
const int batch_size = static_cast<int>(x_dims[0]);
int out_height = phi::funcs::CalcOutputSize(x_dims[2],
kernel_sizes[0],
dilations[0],
paddings[0],
paddings[2],
strides[0]);
int out_width = phi::funcs::CalcOutputSize(x_dims[3],
kernel_sizes[1],
dilations[1],
paddings[1],
paddings[3],
strides[1]);
DDim x_shape = make_ddim({x_dims[1], x_dims[2], x_dims[3]});
DDim out_matrix_shape = make_ddim(
{x_dims[1], kernel_sizes[0], kernel_sizes[1], out_height, out_width});
paddle::operators::math::
Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
col2im;
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(ctx, x_grad, static_cast<T>(0));
for (int i = 0; i < batch_size; i++) {
DenseTensor out_grad_batch =
out_grad.Slice(i, i + 1).Resize(out_matrix_shape);
DenseTensor x_grad_batch = x_grad->Slice(i, i + 1).Resize(x_shape);
col2im(ctx, out_grad_batch, dilations, strides, paddings, &x_grad_batch);
}
}
} // namespace phi
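UnfoldGradKernel zero-fills x_grad before the loop because col2im accumulates: overlapping sliding windows each add their share of the gradient back into the same image element. A 1-D toy version of that accumulate-into-image step (illustrative only; the caller pre-sizes and zeroes im):

    #include <vector>

    // col is laid out as [out_len, kernel]; every element adds back into its
    // source position, so overlapping windows sum their contributions.
    void Col2Im1D(const std::vector<float>& col, int kernel, int stride,
                  std::vector<float>* im) {
      const int out_len = static_cast<int>(col.size()) / kernel;
      for (int w = 0; w < out_len; ++w) {
        for (int k = 0; k < kernel; ++k) {
          (*im)[w * stride + k] += col[w * kernel + k];
        }
      }
    }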
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h"
namespace phi {
template <typename T, typename Context>
void UnfoldKernel(const Context& ctx,
const DenseTensor& x,
const std::vector<int>& kernel_sizes,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
DenseTensor* out) {
const int batch_size = static_cast<int>(x.dims()[0]);
ctx.template Alloc<T>(out);
paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col;
auto x_dims = x.dims();
int out_height = phi::funcs::CalcOutputSize(x_dims[2],
kernel_sizes[0],
dilations[0],
paddings[0],
paddings[2],
strides[0]);
int out_width = phi::funcs::CalcOutputSize(x_dims[3],
kernel_sizes[1],
dilations[1],
paddings[1],
paddings[3],
strides[1]);
DDim x_shape = make_ddim({x_dims[1], x_dims[2], x_dims[3]});
DDim out_matrix_shape = make_ddim(
{x_dims[1], kernel_sizes[0], kernel_sizes[1], out_height, out_width});
for (int i = 0; i < batch_size; i++) {
DenseTensor in_batch = x.Slice(i, i + 1).Resize(x_shape);
DenseTensor out_batch = out->Slice(i, i + 1).Resize(out_matrix_shape);
im2col(ctx, in_batch, dilations, strides, paddings, &out_batch);
}
}
} // namespace phi
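For orientation, a hypothetical call site for the CPU build (setup elided and names assumed; the PR exercises the kernel through the Python unfold API rather than calling it directly):

    // phi::CPUContext ctx = ...;   // device context
    // phi::DenseTensor x = ...;    // shape [N, C, H, W]
    // phi::DenseTensor out;
    // phi::UnfoldKernel<float, phi::CPUContext>(
    //     ctx, x,
    //     /*kernel_sizes=*/{2, 2}, /*strides=*/{1, 1},
    //     /*paddings=*/{0, 0, 0, 0}, /*dilations=*/{1, 1}, &out);
    // // out.dims() == [N, C * 2 * 2, out_h * out_w]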
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void UnfoldGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const std::vector<int>& kernel_sizes,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void UnfoldKernel(const Context& ctx,
const DenseTensor& x,
const std::vector<int>& kernel_sizes,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
DenseTensor* out);
} // namespace phi
-/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
-Indicesou may obtain a copy of the License at
+You may obtain a copy of the License at
 http://www.apache.org/licenses/LICENSE-2.0
@@ -12,15 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/unfold_op.h"
-
-namespace ops = paddle::operators;
-
-REGISTER_OP_CUDA_KERNEL(
-    unfold, ops::UnfoldOpKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::UnfoldOpKernel<paddle::platform::CUDADeviceContext, double>);
-
-REGISTER_OP_CUDA_KERNEL(
-    unfold_grad,
-    ops::UnfoldGradOpKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::UnfoldGradOpKernel<paddle::platform::CUDADeviceContext, double>);
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature UnfoldGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("unfold_grad",
+                         {"X", GradVarName("Y")},
+                         {"kernel_sizes", "strides", "paddings", "dilations"},
+                         {GradVarName("X")});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(unfold_grad, phi::UnfoldGradOpArgumentMapping);
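Read as a translation rule, the signature above says: when the legacy unfold_grad op is executed, invoke the phi kernel registered as "unfold_grad" with X and Y@GRAD as inputs, the four attributes in order, and X@GRAD as the output. Schematically (a paraphrase, not framework code):

    // unfold_grad(X, Y@GRAD; kernel_sizes, strides, paddings, dilations)
    //     -> phi::UnfoldGradKernel(ctx,
    //                              x        = X,
    //                              out_grad = Y@GRAD,
    //                              kernel_sizes, strides, paddings, dilations,
    //                              x_grad   = X@GRAD)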