Unverified commit f5342918, authored by Wang Xin, committed by GitHub

static graph autogen code for shape op (#54221)

* static graph autogen code for shape op

* fix onednn

* fix onednn
Parent commit: c642aa17
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class ShapeOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Pick the kernel based on the dtype of the "Input" variable and the
  // place the op is executing on.
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    const auto dtype =
        framework::OperatorWithKernel::IndicateVarDataType(ctx, "Input");
    return phi::KernelKey(dtype, ctx.GetPlace());
  }

 protected:
  // Accept the input on any backend with its current layout: returning
  // phi::Backend::ALL_BACKEND here means no data transform (e.g. device
  // copy) is forced just to read the tensor's shape metadata.
  phi::KernelKey GetKernelTypeForVar(
      const std::string &var_name,
      const phi::DenseTensor &tensor,
      const phi::KernelKey &expected_kernel_type) const override {
    return phi::KernelKey(phi::Backend::ALL_BACKEND,
                          tensor.layout(),
                          expected_kernel_type.dtype());
  }
};
class ShapeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Declares the op's proto: one tensor input ("Input") and one tensor
  // output ("Out") that holds the input's shape.
  void Make() override {
    AddInput("Input", "(phi::DenseTensor), The input tensor.");
    AddOutput(
        "Out",
        "(phi::DenseTensor), The shape of input tensor, the data type of the"
        " shape is int32_t, will be on the same device with the input Tensor.");
    AddComment(R"DOC(
Shape Operator.
Return the shape of the input.
)DOC");
  }
};
// "Input" is consumed only for its shape metadata; its data buffer is never
// read, so the framework may free/skip the buffer for this op.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ShapeNoNeedBufferVarsInferer, "Input");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
// Route the op's static-graph InferShape through the phi ShapeInferMeta
// function instead of a hand-written InferShape method.
DECLARE_INFER_SHAPE_FUNCTOR(shape,
ShapeInferShapeFunctor,
PD_INFER_META(phi::ShapeInferMeta));
// shape has no gradient, hence the EmptyGradOpMaker for both the static
// (OpDesc) and dynamic (OpBase) graph modes.
REGISTER_OPERATOR(
shape,
ops::ShapeOp,
ops::ShapeOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::ShapeNoNeedBufferVarsInferer,
ShapeInferShapeFunctor);
......@@ -903,17 +903,6 @@
intermediate : noise
backward : rrelu_grad
- op : shape
args : (Tensor input)
output : Tensor(out)
infer_meta :
func : ShapeInferMeta
kernel :
func : shape {dense -> dense},
shape_sr {selected_rows -> selected_rows}
data_transform:
skip_transform : input
- op : slice
args : (Tensor input, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis)
output : Tensor
......
......@@ -2135,6 +2135,12 @@
extra :
attrs : [bool use_mkldnn=false]
- op : shape
inputs :
input : Input
outputs :
out : Out
- op : shape
extra :
attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]
......
......@@ -1959,6 +1959,17 @@
optional : master_param, master_param_out
inplace : (param -> param_out), (master_param -> master_param_out)
- op : shape
args : (Tensor input)
output : Tensor(out)
infer_meta :
func : ShapeInferMeta
kernel :
func : shape {dense -> dense},
shape_sr {selected_rows -> selected_rows}
data_transform:
skip_transform : input
- op : shard_index
args : (Tensor input, int index_num, int nshards, int shard_id, int ignore_value=-1)
output : Tensor(out)
......
......@@ -56,4 +56,6 @@ PD_REGISTER_KERNEL(shape,
float,
phi::dtype::bfloat16,
int8_t,
uint8_t) {}
uint8_t) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
......@@ -42,7 +42,6 @@ cc_test_old(
recurrent_op
op_registry
pool_op
shape_op
crop_op
activation_op
generated_op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.