未验证 提交 e531bb02 编写于 作者: H huangjiyi 提交者: GitHub

Support static graph code generation for op strided_slice (#54098)

* update

* update

* update

* update

* update

* update

* update

* update
上级 7f696804
...@@ -33,7 +33,7 @@ from type_mapping import ( ...@@ -33,7 +33,7 @@ from type_mapping import (
def get_infer_var_type_func(op_name): def get_infer_var_type_func(op_name):
if op_name == "assign": if op_name == "assign":
return f""" return f"""
class {to_pascal_case(op_name)}InferVarType : public framework::VarTypeInference {{ class {to_pascal_case(op_name)}InferVarType : public framework::VarTypeInference {{
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override {{ void operator()(framework::InferVarTypeContext *ctx) const override {{
ctx->SyncTypeAndDataType("X", "Out"); ctx->SyncTypeAndDataType("X", "Out");
...@@ -64,16 +64,37 @@ class {to_pascal_case(op_name)}InferVarType : public framework::VarTypeInference ...@@ -64,16 +64,37 @@ class {to_pascal_case(op_name)}InferVarType : public framework::VarTypeInference
""" """
elif op_name == "merge_selected_rows": elif op_name == "merge_selected_rows":
return f""" return f"""
class {to_pascal_case(op_name)}InferVarType class {to_pascal_case(op_name)}InferVarType : public framework::PassInDtypeAndVarTypeToOutput {{
: public framework::PassInDtypeAndVarTypeToOutput {{
protected: protected:
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType() std::unordered_map<std::string, std::string>& GetInputOutputWithSameType() const override {{
const override {{
static std::unordered_map<std::string, std::string> m{{{{"X", /*->*/ "Out"}}}}; static std::unordered_map<std::string, std::string> m{{{{"X", /*->*/ "Out"}}}};
return m; return m;
}} }}
}}; }};
""" """
elif op_name == "strided_slice":
return f"""
class {to_pascal_case(op_name)}InferVarType : public framework::VarTypeInference {{
public:
void operator()(framework::InferVarTypeContext *ctx) const override {{
ctx->SetOutputType("Out", ctx->GetInputType("Input"));
ctx->SetOutputDataType("Out", ctx->GetInputDataType("Input"));
}}
}};
"""
elif op_name == "strided_slice_grad":
return f"""
class {to_pascal_case(op_name)}InferVarType : public framework::VarTypeInference {{
public:
void operator()(framework::InferVarTypeContext *ctx) const override {{
ctx->SetOutputType(framework::GradVarName("Input"),
ctx->GetInputType(framework::GradVarName("Out")));
ctx->SetOutputDataType(
framework::GradVarName("Input"),
ctx->GetInputDataType(framework::GradVarName("Out")));
}}
}};
"""
else: else:
return None return None
......
...@@ -178,6 +178,47 @@ phi::KernelKey GetSoftmaxGradExpectedKernelType( ...@@ -178,6 +178,47 @@ phi::KernelKey GetSoftmaxGradExpectedKernelType(
ctx.GetPlace(), layout_, phi::TransToPhiDataType(input_data_type)); ctx.GetPlace(), layout_, phi::TransToPhiDataType(input_data_type));
} }
phi::KernelKey GetStridedSliceExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr) {
auto* in_var = ctx.InputVar("Input");
auto is_in_var_array = in_var->IsType<framework::LoDTensorArray>();
if (is_in_var_array) {
auto& tensor_array = in_var->Get<framework::LoDTensorArray>();
for (auto& tensor : tensor_array) {
if (!platform::is_cuda_pinned_place(tensor.place())) {
PADDLE_ENFORCE_EQ(
platform::is_same_place(tensor.place(),
ctx.device_context().GetPlace()),
true,
platform::errors::InvalidArgument(
"Place of context is %s. Place of input tensor is %s. They "
"are should be same, but reveived different place.",
string::to_string(ctx.device_context().GetPlace()),
string::to_string(tensor.place())));
}
}
return phi::KernelKey(op_ptr->IndicateVarDataType(ctx, "Input"),
ctx.GetPlace());
}
// NOTE: cuda pinned tensor need to copy its data to target place
auto in_tensor = ctx.Input<phi::DenseTensor>("Input");
if (platform::is_cuda_pinned_place(in_tensor->place())) {
return phi::KernelKey(framework::TransToProtoVarType(in_tensor->dtype()),
ctx.GetPlace());
}
return phi::KernelKey(op_ptr->IndicateVarDataType(ctx, "Input"),
in_tensor->place());
}
phi::KernelKey GetStridedSliceGradExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr) {
return phi::KernelKey(
op_ptr->IndicateVarDataType(ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
phi::KernelKey GetUpdateLossScalingExpectedKernelType( phi::KernelKey GetUpdateLossScalingExpectedKernelType(
const framework::ExecutionContext& ctx, const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr) { const framework::OperatorWithKernel* op_ptr) {
......
...@@ -44,6 +44,14 @@ phi::KernelKey GetSoftmaxGradExpectedKernelType( ...@@ -44,6 +44,14 @@ phi::KernelKey GetSoftmaxGradExpectedKernelType(
const framework::ExecutionContext& ctx, const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr); const framework::OperatorWithKernel* op_ptr);
phi::KernelKey GetStridedSliceExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);
phi::KernelKey GetStridedSliceGradExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);
phi::KernelKey GetUpdateLossScalingExpectedKernelType( phi::KernelKey GetUpdateLossScalingExpectedKernelType(
const framework::ExecutionContext& ctx, const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr); const framework::OperatorWithKernel* op_ptr);
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"
namespace paddle {
namespace operators {
class StridedSliceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto *in_var = ctx.InputVar("Input");
auto is_in_var_array = in_var->IsType<framework::LoDTensorArray>();
if (is_in_var_array) {
auto &tensor_array = in_var->Get<framework::LoDTensorArray>();
for (auto &tensor : tensor_array) {
if (!platform::is_cuda_pinned_place(tensor.place())) {
PADDLE_ENFORCE_EQ(
platform::is_same_place(tensor.place(),
ctx.device_context().GetPlace()),
true,
platform::errors::InvalidArgument(
"Place of context is %s. Place of input tensor is %s. They "
"are should be same, but reveived different place.",
string::to_string(ctx.device_context().GetPlace()),
string::to_string(tensor.place())));
}
}
return phi::KernelKey(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.GetPlace());
}
// NOTE: cuda pinned tensor need to copy its data to target place
auto in_tensor = ctx.Input<phi::DenseTensor>("Input");
if (platform::is_cuda_pinned_place(in_tensor->place())) {
return phi::KernelKey(framework::TransToProtoVarType(in_tensor->dtype()),
ctx.GetPlace());
}
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
in_tensor->place());
}
phi::KernelKey GetKernelTypeForVar(
const std::string &var_name,
const phi::DenseTensor &tensor,
const phi::KernelKey &expected_kernel_type) const override {
if (var_name == "StartsTensor" || var_name == "EndsTensor" ||
var_name == "StridesTensor") {
return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
}
if (var_name == "StartsTensorList" || var_name == "EndsTensorList" ||
var_name == "StridesTensorList") {
return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
}
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
};
class StridedSliceOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
ctx->SetOutputType("Out", ctx->GetInputType("Input"));
ctx->SetOutputDataType("Out", ctx->GetInputDataType("Input"));
}
};
class StridedSliceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "Tensor of data to extract slices from.");
AddOutput("Out", "Strided Sliced data tensor.");
AddInput("StartsTensor",
"(Tensor<int32>, optional) If provided, slice will use this."
"It has the highest priority of StartsTensor, StartsTensorList "
"and attr(starts).")
.AsDispensable();
AddInput("EndsTensor",
"(Tensor<int32>, optional) If provided, slice will use this."
"It has the highest priority of EndsTensor, EndsTensorList and "
"attr(ends).")
.AsDispensable();
AddInput(
"StridesTensor",
"(Tensor<int32>, optional) If provided, slice will use this."
"It has the highest priority of StridesTensor, StridesTensorList and "
"attr(ends).")
.AsDispensable();
AddInput(
"StartsTensorList",
"(vector<Tensor<int32>>, optional) If provided, slice will use this."
"The shape of the tensor in vector MUST BE [1]."
"It has higher priority compare with attr(starts).")
.AsDuplicable()
.AsDispensable();
AddInput(
"EndsTensorList",
"(vector<Tensor<int32>>, optional) If provided, slice will use this."
"The shape of the tensor in vector MUST BE [1]."
"It has higher priority compare with attr(ends).")
.AsDuplicable()
.AsDispensable();
AddInput(
"StridesTensorList",
"(vector<Tensor<int32>>, optional) If provided, slice will use this."
"The shape of the tensor in vector MUST BE [1]."
"It has higher priority compare with attr(strides).")
.AsDuplicable()
.AsDispensable();
AddAttr<std::vector<int>>(
"axes", "(list<int>) Axes that `starts` and `ends` apply to.");
AddAttr<std::vector<int>>(
"starts", "(list<int>) Start indices for the strided slice start.")
.SetDefault({});
AddAttr<std::vector<int>>("ends",
"(list<int>) End indices the tensor slice end")
.SetDefault({});
AddAttr<std::vector<int>>(
"strides", "(list<int> Stride step from the start to the end)")
.SetDefault({});
AddAttr<std::vector<int>>(
"infer_flags", "(list<int>) Flags of inferring dims in attributes.")
.SetDefault({});
AddAttr<std::vector<int>>("decrease_axis", "(list<int>) decrease_axis")
.SetDefault({});
AddComment(R"DOC(
Strided Slice Operator.
Instead of calling this op directly most users will want to use the
NumPy-style slicing syntax.
For Example:
data = fluid.layers.fill_constant(shape=[3, 3], value=0, dtype='int64')
y = fluid.layers.strided_slice(data, [0, 1], [1,0], [2, 3], [1, 1])
)DOC");
}
};
class StridedSliceOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
phi::KernelKey GetKernelTypeForVar(
const std::string &var_name,
const phi::DenseTensor &tensor,
const phi::KernelKey &expected_kernel_type) const override {
if (var_name == "StartsTensor" || var_name == "EndsTensor" ||
var_name == "StridesTensor") {
return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
}
if (var_name == "StartsTensorList" || var_name == "EndsTensorList" ||
var_name == "StridesTensorList") {
return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
}
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
};
template <typename T>
class StridedSliceOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> bind) const override {
bind->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
bind->SetInput("Input", this->Input("Input"));
bind->SetInput("StartsTensor", this->Input("StartsTensor"));
bind->SetInput("EndsTensor", this->Input("EndsTensor"));
bind->SetInput("StridesTensor", this->Input("StridesTensor"));
bind->SetInput("StartsTensorList", this->Input("StartsTensorList"));
bind->SetInput("EndsTensorList", this->Input("EndsTensorList"));
bind->SetInput("StridesTensorList", this->Input("StridesTensorList"));
bind->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
bind->SetAttrMap(this->Attrs());
bind->SetType("strided_slice_grad");
}
};
class StridedSliceGradOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
ctx->SetOutputType(framework::GradVarName("Input"),
ctx->GetInputType(framework::GradVarName("Out")));
ctx->SetOutputDataType(
framework::GradVarName("Input"),
ctx->GetInputDataType(framework::GradVarName("Out")));
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer,
"Input");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(strided_slice,
StridedSliceInferShape,
PD_INFER_META(phi::StridedSliceRawInferMeta));
REGISTER_OPERATOR(strided_slice,
ops::StridedSliceOp,
ops::StridedSliceOpMaker,
ops::StridedSliceOpGradMaker<paddle::framework::OpDesc>,
ops::StridedSliceOpGradMaker<paddle::imperative::OpBase>,
ops::StridedSliceOpVarTypeInference,
StridedSliceInferShape);
DECLARE_INFER_SHAPE_FUNCTOR(strided_slice_grad,
StridedSliceGradInferShape,
PD_INFER_META(phi::GeneralUnaryGradInferMeta));
REGISTER_OPERATOR(strided_slice_grad,
ops::StridedSliceOpGrad,
ops::StridedSliceOpGradNoNeedBufferVarsInferer,
ops::StridedSliceGradOpVarTypeInference,
StridedSliceGradInferShape);
...@@ -2267,6 +2267,30 @@ ...@@ -2267,6 +2267,30 @@
outputs : outputs :
out : Out out : Out
- op : strided_slice
backward : strided_slice_grad
inputs :
x : Input
outputs :
out : Out
int_array :
starts :
data_type : int
tensor_name : StartsTensor
tensors_name : StartsTensorList
ends :
data_type : int
tensor_name : EndsTensor
tensors_name : EndsTensorList
strides :
data_type : int
tensor_name : StridesTensor
tensors_name : StridesTensorList
manual_signature : [strided_slice, strided_slice_grad]
get_expected_kernel_type :
strided_slice : GetStridedSliceExpectedKernelType
strided_slice_grad : GetStridedSliceGradExpectedKernelType
- op : subtract (elementwise_sub) - op : subtract (elementwise_sub)
backward : subtract_grad (elementwise_sub_grad) backward : subtract_grad (elementwise_sub_grad)
inputs : inputs :
......
...@@ -76,6 +76,19 @@ ...@@ -76,6 +76,19 @@
func : softmax_grad func : softmax_grad
composite : softmax_grad(out, out_grad, axis, x_grad) composite : softmax_grad(out, out_grad, axis, x_grad)
- backward_op : strided_slice_grad
forward : strided_slice (Tensor x, int[] axes, IntArray starts={}, IntArray ends={}, IntArray strides={}, int[] infer_flags={}, int[] decrease_axis={}) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int[] axes, IntArray starts, IntArray ends, IntArray strides, int[] infer_flags, int[] decrease_axis)
output : Tensor(x_grad)
infer_meta :
func : GeneralUnaryGradInferMeta
param : [x]
kernel :
func : strided_slice_grad
param : [x, axes, starts, ends, strides]
data_type : out_grad
no_need_buffer : x
- backward_op : tril_triu_grad - backward_op : tril_triu_grad
forward : tril_triu (Tensor x, int diagonal = 0, bool lower = false) -> Tensor(out) forward : tril_triu (Tensor x, int diagonal = 0, bool lower = false) -> Tensor(out)
args : (Tensor out_grad, int diagonal, bool lower) args : (Tensor out_grad, int diagonal, bool lower)
......
...@@ -355,6 +355,16 @@ ...@@ -355,6 +355,16 @@
inplace : (x -> out) inplace : (x -> out)
backward : softmax_grad backward : softmax_grad
- op : strided_slice
args : (Tensor x, int[] axes, IntArray starts={}, IntArray ends={}, IntArray strides={}, int[] infer_flags={}, int[] decrease_axis={})
output : Tensor
infer_meta :
func : StridedSliceRawInferMeta
kernel :
func : strided_slice
param : [x, axes, starts, ends, strides]
backward : strided_slice_grad
- op : tril_indices - op : tril_indices
args : (int rows = 0, int cols = 0, int offset = 0, DataType dtype = DataType::INT64) args : (int rows = 0, int cols = 0, int offset = 0, DataType dtype = DataType::INT64)
output : Tensor(out) output : Tensor(out)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册