Unverified commit 0dccdee0, authored by Chen Weihang, committed by GitHub

remove unchanged infermeta new (#39343)

Parent: 633c71c2
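This commit removes the temporary pten::UnchangedInferMetaNew variant and routes shape inference for the conj and sign operators through the existing pten::UnchangedInferMeta, registered via the infer-shape functor macro instead of a dedicated "New" function. As background, the sketch below (standalone plain C++, not code from this diff; the Meta struct and names are illustrative stand-ins) shows the idea behind an "unchanged" infer-meta function: the output simply takes over the input's metadata.

// Illustrative sketch only -- stand-in types, not Paddle's real API.
#include <cstdint>
#include <iostream>
#include <vector>

struct Meta {
  std::vector<int64_t> dims;
  int dtype = 0;  // stand-in for a real dtype enum
};

// An "unchanged" infer-meta function: forward the input's metadata to the
// output without modifying it, so one generic function can serve many
// element-wise ops.
void UnchangedInferMeta(const Meta& x, Meta* out) { *out = x; }

int main() {
  Meta x{{2, 3}, /*dtype=*/1};
  Meta out;
  UnchangedInferMeta(x, &out);
  std::cout << out.dims[0] << "x" << out.dims[1] << "\n";  // prints 2x3
  return 0;
}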
@@ -22,22 +22,16 @@
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/infermeta/unary.h"
namespace paddle {
namespace operators {
class ConjOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "conj");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "conj");

    auto in_dims = ctx->GetInputDim("X");

    ctx->SetOutputDim("Out", in_dims);
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};
class ConjOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -72,9 +66,12 @@ class ConjGradMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(conj, ConjInferShapeFunctor,
                            PT_INFER_META(pten::UnchangedInferMeta));
REGISTER_OPERATOR(conj, ops::ConjOp, ops::ConjOpMaker,
                  ops::ConjGradMaker<paddle::framework::OpDesc>,
                  ops::ConjGradMaker<paddle::imperative::OpBase>);
                  ops::ConjGradMaker<paddle::imperative::OpBase>,
                  ConjInferShapeFunctor);
REGISTER_OP_CPU_KERNEL(
    conj, ops::ConjKernel<paddle::platform::CPUDeviceContext,
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
@@ -60,7 +61,7 @@ class SignGradMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(sign, SignInferShapeFunctor,
                            PT_INFER_META(pten::UnchangedInferMetaNew));
                            PT_INFER_META(pten::UnchangedInferMeta));
REGISTER_OPERATOR(sign, ops::SignOp, ops::SignOpMaker<float>,
                  ops::SignGradMaker<paddle::framework::OpDesc>,
                  ops::SignGradMaker<paddle::imperative::OpBase>,
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/pten/kernels/sign_kernel.h"
namespace paddle {
namespace operators {
// See Note [ Why still keep the original kernel implementation? ]
template <typename DeviceContext, typename T>
class SignKernel : public framework::OpKernel<T> {
 public:
  virtual void Compute(const framework::ExecutionContext& context) const {
    auto* x = context.Input<framework::Tensor>("X");
    auto* out = context.Output<framework::Tensor>("Out");
    auto& dev_ctx = context.device_context<DeviceContext>();
    out->mutable_data<T>(x->place());

    // call new kernel
    pten::SignKernel<T, typename paddle::framework::ConvertToPtenContext<
                            DeviceContext>::TYPE>(
        static_cast<const typename paddle::framework::ConvertToPtenContext<
            DeviceContext>::TYPE&>(dev_ctx),
        *x, out);
  }
};
} // namespace operators
} // namespace paddle
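The sign_op.h shown above keeps the legacy fluid kernel as a thin wrapper: it allocates the output via mutable_data, then casts the device context to the corresponding pten context type (ConvertToPtenContext) and forwards the work to pten::SignKernel. A minimal adapter-pattern sketch of that idea, using stand-in types rather than Paddle's real classes:

// Illustrative sketch only -- stand-in types, not Paddle's real API.
#include <iostream>

struct LegacyContext { int device_id = 0; };  // stand-in for the fluid device context
struct NewContext : LegacyContext {};         // stand-in for the pten device context

// The "new" kernel that does the actual work.
void NewSignKernel(const NewContext& /*ctx*/, float x, float* out) {
  *out = static_cast<float>((x > 0.f) - (x < 0.f));
}

// Legacy entry point: only adapts the context type and forwards.
void LegacySignCompute(const LegacyContext& ctx, float x, float* out) {
  NewSignKernel(static_cast<const NewContext&>(ctx), x, out);
}

int main() {
  NewContext ctx;
  float out = 0.f;
  LegacySignCompute(ctx, -3.5f, &out);
  std::cout << out << "\n";  // prints -1
  return 0;
}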
@@ -21,16 +21,6 @@ limitations under the License. */
namespace pten {
void UnchangedInferMetaNew(MetaConfig config,
                           const MetaTensor& x,
                           MetaTensor* out) {
  out->share_meta(x);
}

DenseTensorMeta UnchangedInferMeta(const DenseTensorMeta& x_meta) {
  return x_meta;
}

void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out) {
  out->share_meta(x);
}
@@ -319,4 +309,4 @@ void ReduceInferMeta(const MetaTensor& x,
} // namespace pten
PT_REGISTER_INFER_META_FN(sign, pten::UnchangedInferMetaNew);
PT_REGISTER_INFER_META_FN(sign, pten::UnchangedInferMeta);
@@ -31,11 +31,6 @@ class MetaConfig;
// Because functions in this file not only can infer shape, but also need
// infer lod or other useful data.
// TODO(chenweihang): to avoid conflit, remove this function in next PR
void UnchangedInferMetaNew(MetaConfig config,
                           const MetaTensor& x,
                           MetaTensor* out);

void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out);
void FlattenInferMeta(const MetaTensor& x,
......
@@ -29,7 +29,7 @@ TEST(MetaFunctionMap, InferMetaFnExists) {
  pten::MetaTensor meta_x(&dense_x);

  pten::DenseTensor dense_out1;
  pten::MetaTensor meta_out(&dense_out1);
  pten::UnchangedInferMetaNew(/*is_runtime=*/true, meta_x, &meta_out);
  pten::UnchangedInferMeta(meta_x, &meta_out);

  auto shared_meat_x = std::make_shared<pten::MetaTensor>(&dense_x);
  pten::DenseTensor dense_out2;
......