Unverified commit 58f08924, authored by zyfncg, committed by GitHub

Support static graph code-gen for scalar and int_array (#48792)

* add support_tensor for code_gen to static graph

* support code-gen for int_array

* polish code

* fix bug of data_type
Parent: ff8b2cb7
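For context, the generated OpMaker and argument-mapping code in this commit is driven by per-op `scalar` and `int_array` entries in op_compat.yaml. Below is a rough Python sketch of what those parsed configs look like to the generator scripts, mirroring the multinomial, topk and crop entries added later in this commit (the variable names here are illustrative, not part of the generator):

# Illustrative only: the shape of parsed op_compat.yaml configs consumed by
# process_scalar / process_int_array; these variable names are hypothetical.
scalar_configs_example = {
    "multinomial": {"num_samples": {"data_type": "int", "support_tensor": True}},
    "topk": {"k": {"data_type": "int", "tensor_name": "K"}},
}
int_array_configs_example = {
    "crop": {
        "shape": {"data_type": "int", "tensor_name": "Shape",
                  "tensors_name": "ShapeTensor"},
        "offsets": {"data_type": "int", "tensor_name": "Offsets",
                    "tensors_name": "OffsetsTensor"},
    },
}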
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
// TODO(freeliuzc): Delete old infershape
// New infershape has already in unary.h and backward.h
namespace paddle {
namespace operators {
class CropTensorOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CropTensor");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CropTensor");
auto x_dim = ctx->GetInputDim("X");
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
auto offsets = ctx->Attrs().Get<std::vector<int>>("offsets");
if (ctx->HasInputs("ShapeTensor")) {
// top priority shape
auto inputs_name = ctx->Inputs("ShapeTensor");
PADDLE_ENFORCE_GT(
inputs_name.size(),
0,
platform::errors::InvalidArgument(
"The number of elements of the input 'ShapeTensor' for "
"CropTensor must be greater than zero, "
"but the value received is %d.",
inputs_name.size()));
auto out_dims = std::vector<int>(inputs_name.size(), -1);
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] > 0) {
out_dims[i] = static_cast<int64_t>(shape[i]);
} else {
if (shape[i] == -1 && offsets[i] != -1 && x_dim[i] != -1) {
out_dims[i] = x_dim[i] - static_cast<int64_t>(offsets[i]);
}
}
}
ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
return;
}
if (ctx->HasInput("Shape")) {
auto shape_dim = ctx->GetInputDim("Shape");
PADDLE_ENFORCE_EQ(shape_dim.size(),
1,
platform::errors::InvalidArgument(
"The number of dimensions of the input "
"'Shape' for CropTensor must be 1, "
"but the value received is %d.",
shape_dim.size()));
PADDLE_ENFORCE_EQ(shape_dim[0],
x_dim.size(),
platform::errors::InvalidArgument(
"The number of elements (%d) of the input 'Shape' "
"for CropTensor must be equal to the number of"
" dimensions (%d) of the input.",
shape_dim[0],
x_dim.size()));
if (ctx->IsRuntime()) {
// If true, set the shape of Output(Out) according to Input(Shape) in
// CropKernel with ExecutionContext. Also check LoD in
// CropKernel.
ctx->ShareLoD("X", /*->*/ "Out");
} else {
auto out_dims = std::vector<int>(shape_dim[0], -1);
ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
}
return;
}
PADDLE_ENFORCE_EQ(
int64_t(shape.size()),
x_dim.size(),
platform::errors::InvalidArgument(
"The number of elements (%d) of attribute 'shape' for "
"CropTensor must be equal to the number of "
"dimensions (%d) of the input.",
shape.size(),
x_dim.size()));
std::vector<int64_t> out_shape(shape.size(), -1);
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] > 0) {
out_shape[i] = static_cast<int64_t>(shape[i]);
} else {
if (shape[i] == -1 && offsets[i] != -1 && x_dim[i] != -1) {
out_shape[i] = x_dim[i] - static_cast<int64_t>(offsets[i]);
}
}
}
ctx->SetOutputDim("Out", phi::make_ddim(out_shape));
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name,
const phi::DenseTensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "ShapeTensor" || var_name == "OffsetsTensor" ||
var_name == "Shape" || var_name == "Offsets") {
return expected_kernel_type;
}
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}
};
class CropTensorOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input of pad op. "
"The input should be a k-D tensor(k > 0 and k < 7).");
AddInput("Shape",
"The input used to describe shape of output, which is a "
"1-D vector whose size equals to the rank of input 'X'. The "
"elements data type must be int. It has a higher priority than "
"the shape attribute")
.AsDispensable();
AddInput("Offsets",
"The input used to describe offsets in runtime, which is a "
"1-D vector whose size equals to the rank of input 'X'. The "
"elements data type must be int. It has a higher priority than "
"the offsets attribute")
.AsDispensable();
AddInput("ShapeTensor",
"(vector<Tensor<int32>>, optional). If provided, crop_tensor will "
"use this. The shape of the tensor in vector MUST BE [1]. "
"It has the highest priority compare with Input(Shape) and "
"attr(shape).")
.AsDuplicable()
.AsDispensable();
AddInput("OffsetsTensor",
"(vector<Tensor<int32>>, optional). If provided, crop_tensor will "
"use this. The shape of the tensor in vector MUST BE [1]. "
"It has the highest priority compare with Input(Offsets) and "
"attr(offsets).")
.AsDuplicable()
.AsDispensable();
AddOutput("Out",
"The output of crop_tensor op, "
"which is of the same dimensions as X.");
AddAttr<std::vector<int>>("offsets",
"A list<int> describing offsets to be cropped. "
"The size of offsets list should be the same as "
"the dimension size of input X.")
.SetDefault(std::vector<int>());
AddAttr<std::vector<int>>("shape",
"A list<int> describing the shape of output. "
"The size of shape list should be the same as "
"the dimension size of input X.")
.SetDefault(std::vector<int>());
AddComment(R"DOC(
CropTensor Operator.
Crop input into output, as specified by offsets and shape.
There are three ways to set the offsets:
1. Input 'OffsetsTensor': It is a tensor list. It should be set as a list that
contains tensor variable in python configure script.
This way is suitable for dynamic offsets.
2. Input 'Offsets': It is a variable and can be output of other operators.
This way is suitable for dynamic offsets.
3. Attribute 'offsets': It will be set in python configure script. This way
is suitable for fixed offsets.
You CANNOT use these three ways at the same time. An exception will be raised
if input 'OffsetsTensor' or 'Offsets' is configured and meanwhile the attribute 'offsets' is
not empty.
There are three ways to set shape:
1. Input 'ShapeTensor': It is a tensor list. It should be set as a list that contains
tensor variable in python configure script. This way is suitable
for dynamic shape.
2. Input 'Shape': It is a Variable and can be output of other operators. This way is suitable
for dynamic shape.
3. Attribute 'shape': crop input X into the shape described by a list<int>. The size of shape
list should be the same as the dimension size of input X. This way is
suitable for fixed shape.
The input should be a k-D tensor(k > 0 and k < 7). As an example:
Case 1:
Given
X = [[0, 1, 2, 0, 0]
[0, 3, 4, 0, 0]
[0, 0, 0, 0, 0]],
and
offsets = [0, 1],
and
shape = [2, 2],
we get:
Out = [[1, 2],
[3, 4]].
Case 2:
Given
X = [[0, 1, 2, 5, 0]
[0, 3, 4, 6, 0]
[0, 0, 0, 0, 0]],
and offsets is a list that contains tensor variable,
at runtime offsets_var's value is 1.
offsets = [0, offsets_var],
and shape is a list that contains tensor variable,
at runtime dim's value is 2.
shape = [dim, 3]
we get:
Out = [[1, 2, 5],
[3, 4, 6]].
)DOC");
}
};
class CropTensorOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CropTensorGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
framework::GradVarName("Out"),
"CropTensorGrad");
auto x_dims = ctx->GetInputDim("X");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name,
const phi::DenseTensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "ShapeTensor" || var_name == "OffsetsTensor" ||
var_name == "Shape" || var_name == "Offsets") {
return expected_kernel_type;
}
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}
};
template <typename T>
class CropTensorGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("crop_tensor_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
if (this->HasInput("OffsetsTensor")) {
op->SetInput("OffsetsTensor", this->Input("OffsetsTensor"));
}
if (this->HasInput("Offsets")) {
op->SetInput("Offsets", this->Input("Offsets"));
}
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(crop_tensor,
ops::CropTensorOp,
ops::CropTensorOpMaker,
ops::CropTensorGradOpMaker<paddle::framework::OpDesc>,
ops::CropTensorGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(crop_tensor_grad, ops::CropTensorOpGrad);
@@ -108,18 +108,23 @@ execute_process(
     --op_compat_yaml_path ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml
     --output_op_path "${generated_op_path}.tmp" --output_arg_map_path
     "${generated_argument_mapping_path}.tmp"
+  RESULT_VARIABLE _result)
+if(${_result})
+  message(FATAL_ERROR "operator codegen failed, exiting.")
+endif()
+
+execute_process(
+  WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator
   COMMAND
     ${PYTHON_EXECUTABLE} generate_sparse_op.py --ops_yaml_path
     ./parsed_ops/sparse_ops.parsed.yaml --backward_ops_yaml_path
     ./parsed_ops/sparse_backward.parsed.yaml --output_op_path
     "${generated_sparse_ops_path}.tmp" --output_arg_map_path
     "${generated_sparse_argument_mapping_path}.tmp"
-  RESULT_VARIABLE _results)
-foreach(_result in ${_results})
-  if(${_result})
-    message(FATAL_ERROR "operator codegen failed, exiting.")
-  endif()
-endforeach()
+  RESULT_VARIABLE _result)
+if(${_result})
+  message(FATAL_ERROR "sparse operator codegen failed, exiting.")
+endif()
 
 if(EXISTS "${generated_op_path}.tmp" AND EXISTS "${generated_op_path}")
   execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different
...
@@ -114,17 +114,44 @@ def to_input_name(s):
     return match.group(2)
 
 
+def to_scalar_tensor_name(attr):
+    if 'tensor_name' in attr:
+        return attr['tensor_name']
+    return to_pascal_case(attr['name']) + 'Tensor'
+
+
+def to_int_array_tensor_name(attr):
+    if 'tensor_name' in attr:
+        return attr['tensor_name']
+    return to_pascal_case(attr['name']) + 'Tensor'
+
+
+def to_int_array_tensors_name(attr):
+    if 'tensors_name' in attr:
+        return attr['tensors_name']
+    return to_pascal_case(attr['name']) + 'TensorList'
+
+
 def cartesian_prod_attrs(attrs):
     items = []
     for attr in attrs:
         type_name = attr["typename"]
         name = attr["name"]
         if type_name == "Scalar":
-            items.append((name, "{}Tensor".format(name)))
+            items.append((name, to_scalar_tensor_name(attr)))
         elif type_name == "IntArray":
-            items.append(
-                (name, "{}Tensor".format(name), "{}TensorList".format(name))
-            )
+            if 'tensor_name' not in attr and 'manual_flag' in attr:
+                items.append((name, to_int_array_tensors_name(attr)))
+            elif 'tensors_name' not in attr and 'manual_flag' in attr:
+                items.append((name, to_int_array_tensor_name(attr)))
+            else:
+                items.append(
+                    (
+                        name,
+                        to_int_array_tensor_name(attr),
+                        to_int_array_tensors_name(attr),
+                    )
+                )
         else:
            items.append((name,))
...
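As a quick illustration (not part of the patch itself), the new name filters resolve to the configured compat names when present and fall back to a Pascal-cased default otherwise; a small usage sketch, assuming the filter definitions shown in the hunk above:

# Hand-written example attrs; assumes to_pascal_case turns snake_case into PascalCase.
print(to_scalar_tensor_name({'name': 'k', 'tensor_name': 'K'}))        # -> "K"
print(to_scalar_tensor_name({'name': 'num_samples'}))                  # -> "NumSamplesTensor"
int_array_attr = {'name': 'shape', 'tensor_name': 'Shape', 'tensors_name': 'ShapeTensor'}
print(to_int_array_tensor_name(int_array_attr))                        # -> "Shape"
print(to_int_array_tensors_name(int_array_attr))                       # -> "ShapeTensor"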
@@ -20,10 +20,13 @@ import yaml
 from filters import (
     cartesian_prod_mapping,
     to_input_name,
+    to_int_array_tensor_name,
+    to_int_array_tensors_name,
     to_op_attr_type,
     to_opmaker_name,
     to_opmaker_name_cstr,
     to_pascal_case,
+    to_scalar_tensor_name,
 )
 from jinja2 import Environment, FileSystemLoader, StrictUndefined
 from parse_utils import to_named_dict
@@ -48,6 +51,9 @@ env = Environment(
 env.filters["to_op_attr_type"] = to_op_attr_type
 env.filters["to_opmaker_name"] = to_opmaker_name
 env.filters["to_pascal_case"] = to_pascal_case
+env.filters["to_scalar_tensor_name"] = to_scalar_tensor_name
+env.filters["to_int_array_tensor_name"] = to_int_array_tensor_name
+env.filters["to_int_array_tensors_name"] = to_int_array_tensors_name
 env.filters["to_input_name"] = to_input_name
 env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
 env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
@@ -66,6 +72,86 @@ def restruct_io(op):
     return op
 
 
+def process_scalar(op_item, scalar_configs):
+    scalar_map = {
+        'Scalar': 'float',
+        'Scalar(float)': 'float',
+        'Scalar(int)': 'int',
+        'Scalar(int64_t)': 'int64_t',
+    }
+    if scalar_configs is not None:
+        for attr_item in op_item['attrs']:
+            if attr_item['name'] in scalar_configs:
+                attr_type = attr_item['typename']
+                assert (
+                    attr_type in scalar_map
+                ), f"{op_item['name']}'s scalar in op_compat.yaml is error, the data_type of {attr_item['name']} is expected to be one of Scalar, Scalar(float), Scalar(int) or Scalar(int64_t), but now is {attr_type}."
+
+                scalar_config = scalar_configs[attr_item['name']]
+                attr_item['is_support_tensor'] = (
+                    True
+                    if 'support_tensor' in scalar_config
+                    and scalar_config['support_tensor']
+                    else False
+                )
+                if attr_item['is_support_tensor']:
+                    attr_item['typename'] = (
+                        scalar_config['data_type']
+                        if 'data_type' in scalar_config
+                        else scalar_map[attr_type]
+                    )
+                else:
+                    attr_item['data_type'] = (
+                        scalar_config['data_type']
+                        if 'data_type' in scalar_config
+                        else scalar_map[attr_type]
+                    )
+                    attr_item['tensor_name'] = scalar_config['tensor_name']
+
+
+def process_int_array(op_item, int_array_configs):
+    data_type_map = {
+        'int': 'std::vector<int>',
+        'int64_t': 'std::vector<int64_t>',
+    }
+    if int_array_configs is not None:
+        for attr_item in op_item['attrs']:
+            if attr_item['name'] in int_array_configs:
+                attr_type = attr_item['typename']
+                assert (
+                    attr_item['typename'] == "IntArray"
+                ), f"{op_item['name']}'s int_array in op_compat.yaml is error, the data_type of {attr_item['name']} is expected to be one of IntArray, but now is {attr_type}."
+
+                int_array_config = int_array_configs[attr_item['name']]
+                attr_item['is_support_tensor'] = (
+                    True
+                    if 'support_tensor' in int_array_config
+                    and int_array_config['support_tensor']
+                    else False
+                )
+                if attr_item['is_support_tensor']:
+                    attr_item['typename'] = (
+                        data_type_map[int_array_config['data_type']]
+                        if 'data_type' in int_array_config
+                        else 'std::vector<int64_t>'
+                    )
+                else:
+                    attr_item['data_type'] = (
+                        data_type_map[int_array_config['data_type']]
+                        if 'data_type' in int_array_config
+                        else 'std::vector<int64_t>'
+                    )
+                    attr_item['manual_flag'] = True
+                    if 'tensor_name' in int_array_config:
+                        attr_item['tensor_name'] = int_array_config[
+                            'tensor_name'
+                        ]
+                    if 'tensors_name' in int_array_config:
+                        attr_item['tensors_name'] = int_array_config[
+                            'tensors_name'
+                        ]
+
+
 # replace name of op and params for OpMaker
 def replace_compat_name(op_op_map, forward_op_dict, backward_op_dict):
     def get_op_and_op_name(op_item):
@@ -91,12 +177,26 @@ def replace_compat_name(op_op_map, forward_op_dict, backward_op_dict):
         if new_op_name != op_name:
             forward_op_item['op_name'] = op_name
 
+        scalar_configs = None
+        int_array_configs = None
+        if 'scalar' in op_args:
+            scalar_configs = op_args['scalar']
+        if 'int_array' in op_args:
+            int_array_configs = op_args['int_array']
+        process_scalar(forward_op_item, scalar_configs)
+        process_int_array(forward_op_item, int_array_configs)
+
         if 'backward' in op_args and has_backward:
             backward_op_list = op_args['backward'].split(',')
             _, bw_op_name = get_op_and_op_name(backward_op_list[0])
             forward_op_item['backward'] = bw_op_name
             backward_op_item['op_name'] = bw_op_name
+
+            process_scalar(backward_op_item, scalar_configs)
+            process_int_array(backward_op_item, int_array_configs)
 
             # for double grad
             if len(backward_op_list) > 1:
                 (
@@ -114,6 +214,9 @@ def replace_compat_name(op_op_map, forward_op_dict, backward_op_dict):
                     double_grad_item['forward']['attrs'], op_args['attrs']
                 )
 
+                process_scalar(double_grad_item, scalar_configs)
+                process_int_array(double_grad_item, int_array_configs)
+
                 # for triple grad
                 if len(backward_op_list) > 2:
                     (
@@ -132,6 +235,9 @@ def replace_compat_name(op_op_map, forward_op_dict, backward_op_dict):
                         op_args['attrs'],
                     )
 
+                    process_scalar(triple_grad_item, scalar_configs)
+                    process_int_array(triple_grad_item, int_array_configs)
+
     key_set = ['inputs', 'attrs', 'outputs']
     args_map = {}
     for key in key_set:
...
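To make the effect of the new helpers concrete, here is a small sketch (not part of the patch) of how process_scalar rewrites an attribute item for the two kinds of configuration used in this commit, assuming the function defined in the hunk above:

# topk-style config: a compat tensor name, no support_tensor.
attr = {'name': 'k', 'typename': 'Scalar(int)'}
process_scalar({'name': 'topk', 'attrs': [attr]},
               {'k': {'data_type': 'int', 'tensor_name': 'K'}})
# attr is now {'name': 'k', 'typename': 'Scalar(int)', 'is_support_tensor': False,
#              'data_type': 'int', 'tensor_name': 'K'}

# multinomial-style config: support_tensor=True adds no extra OpMaker input and
# rewrites the typename to the plain attribute type.
attr = {'name': 'num_samples', 'typename': 'Scalar(int)'}
process_scalar({'name': 'multinomial', 'attrs': [attr]},
               {'num_samples': {'data_type': 'int', 'support_tensor': True}})
# attr is now {'name': 'num_samples', 'typename': 'int', 'is_support_tensor': True}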
@@ -20,10 +20,13 @@ import yaml
 from filters import (
     cartesian_prod_mapping,
     to_input_name,
+    to_int_array_tensor_name,
+    to_int_array_tensors_name,
     to_op_attr_type,
     to_opmaker_name,
     to_opmaker_name_cstr,
     to_pascal_case,
+    to_scalar_tensor_name,
 )
 from generate_op import process_invoke_op
 from jinja2 import Environment, FileSystemLoader, StrictUndefined
@@ -49,6 +52,9 @@ env = Environment(
 env.filters["to_op_attr_type"] = to_op_attr_type
 env.filters["to_opmaker_name"] = to_opmaker_name
 env.filters["to_pascal_case"] = to_pascal_case
+env.filters["to_scalar_tensor_name"] = to_scalar_tensor_name
+env.filters["to_int_array_tensor_name"] = to_int_array_tensor_name
+env.filters["to_int_array_tensors_name"] = to_int_array_tensors_name
 env.filters["to_input_name"] = to_input_name
 env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
 env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
...
@@ -17,6 +17,7 @@ from copy import copy
 from typing import Any, Dict, List, Tuple
 
 from tests import is_attr, is_input, is_output, is_vec
+from type_mapping import opmaker_attr_types_map
 
 
 def to_named_dict(items: List[Dict]) -> Dict[str, Dict]:
@@ -97,6 +98,8 @@ def parse_input_and_attr(
                 ), f"{op_name}: Arguments with default value should not precede those without default value"
             elif "default_value" in item:
                 met_attr_with_default_value = True
+            if typename.startswith('Scalar') or typename == 'IntArray':
+                item['data_type'] = opmaker_attr_types_map[typename]
             attrs.append(item)
         else:
             raise KeyError(f"{op_name}: Invalid argument type {typename}.")
...
@@ -61,19 +61,30 @@ AddOutput({{name | to_opmaker_name}}, "({{typename}}), output {{i}} of {{op_name}} op.")
 {% set name = attr["name"] %}
 {% set typename = attr["typename"] %}
 {% if typename is scalar %}
-AddInput("{{name | to_pascal_case}}Tensor", "attribute {{i}} for {{op_name}} op from 0D Tensor.")
+AddInput("{{attr | to_scalar_tensor_name}}", "attribute {{i}} for {{op_name}} op from 0D Tensor.")
   .AsDispensable();
+AddAttr<{{attr["data_type"]}}>("{{name}}", "({{attr["data_type"]}}), attribute {{i}} for {{op_name}} op.")
 {% elif typename == "IntArray" %}{# the type has been renamed #}
-AddInput("{{name | to_pascal_case}}Tensor", "attribute {{i}} for {{op_name}} op from 1D integer Tensor.")
+{% if 'tensor_name' in attr or 'manual_flag' not in attr %}
+AddInput("{{attr | to_int_array_tensor_name}}", "attribute {{i}} for {{op_name}} op from 1D integer Tensor.")
   .AsDispensable();
-AddInput("{{name | to_pascal_case}}TensorList", "attribute {{i}} for {{op_name}} op from list fo 0D integer Tensors.")
+{% endif %}
+{% if 'tensors_name' in attr or 'manual_flag' not in attr %}
+AddInput("{{attr | to_int_array_tensors_name}}", "attribute {{i}} for {{op_name}} op from list fo 0D integer Tensors.")
   .AsDuplicable()
   .AsDispensable();
 {% endif %}
+AddAttr<{{attr["data_type"]}}>("{{name}}", "({{attr["data_type"]}}), attribute {{i}} for {{op_name}} op.")
+{% else %}
 AddAttr<{{typename | to_op_attr_type}}>("{{name}}", "({{typename | to_op_attr_type}}), attribute {{i}} for {{op_name}} op.")
+{% endif %}
 {% if "default_value" in attr %}
   .SetDefault({{process_default_value(attr)}})
 {%- endif %}
+{% if "is_support_tensor" in attr and attr["is_support_tensor"] %}
+  .SupportTensor()
+{%- endif %}
 {%- endmacro %}
 
 {# process default value for attributes, some attribute has different types and different default values in op & opmaker #}
@@ -104,7 +115,7 @@ KernelSignature {{op["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
   paddle::small_vector<const char*> attrs;
 {% for attr in op["attrs"]%}
 {% filter indent(2)%}
-{{get_an_attr(attr)}};
+{{get_an_attr(attr)}}
 {% endfilter %}
 {% endfor %}
   {{get_output_list(op["outputs"], kernel_args)}};
@@ -159,7 +170,7 @@ KernelSignature {{op["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
   paddle::small_vector<const char*> attrs;
 {% for attr in op["attrs"]%}
 {% filter indent(2)%}
-{{get_an_attr(attr)}};
+{{get_an_attr(attr)}}
 {% endfilter %}
 {% endfor %}
   {{get_output_list(op["outputs"], kernel_args)}};
@@ -202,21 +213,28 @@ paddle::small_vector<const char*> inputs {
 {% set typename = attr["typename"] %}
 {% set name = attr["name"] %}
 {% if typename is scalar %}{# scalar correspond to a dispensable input and an attr in opmaker #}
-attrs.emplace_back(
-  ctx.HasInput("{{name | to_pascal_case}}")
-    ? "{{name | to_pascal_case}}Tensor"
-    : "{{name}}"
-)
+attrs.emplace_back(ctx.HasInput("{{attr | to_scalar_tensor_name}}") ? "{{attr | to_scalar_tensor_name}}" : "{{name}}");
 {%- elif typename == "IntArray" %}
+{% if 'tensor_name' in attr and 'tensors_name' not in attr %}
+attrs.emplace_back(
+  ctx.HasInput("{{attr | to_int_array_tensor_name}}")
+  ? "{{attr | to_int_array_tensor_name}}"
+  : "{{name}}");
+{% elif 'tensor_name' not in attr and 'tensors_name' in attr %}
+attrs.emplace_back(
+  ctx.InputSize("{{attr | to_int_array_tensors_name}}") > 0
+  ? "{{attr | to_int_array_tensors_name}}"
+  : "{{name}}");
+{% else %}
-attrs.emplace_back(
-  ctx.HasInput("{{name | to_pascal_case}}Tensor")
-    ? "{{name | to_pascal_case}}Tensor"
-    : ctx.InputSize("{{name | to_pascal_case}}TensorList") > 0
-      ? "{{name | to_pascal_case}}TensorList"
-      : "{{name}}"
-)
+attrs.emplace_back(
+  ctx.HasInput("{{attr | to_int_array_tensor_name}}")
+  ? "{{attr | to_int_array_tensor_name}}"
+  : ctx.InputSize("{{attr | to_int_array_tensors_name}}") > 0
+  ? "{{attr | to_int_array_tensors_name}}"
+  : "{{name}}");
+{%- endif %}
 {%- else %}
-attrs.emplace_back("{{name}}")
+attrs.emplace_back("{{name}}");
 {%- endif %}
 {%- endmacro %}
@@ -394,10 +412,20 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> {
 {% set attr_name = attr["name"] %}
 {% if attr_name in forward_attr_names %}
 {% if attr["typename"] == "IntArray" %}
-    grad_op->SetInput("{{attr_name | to_pascal_case}}Tensor", this->Input("{{attr_name | to_pascal_case}}Tensor"));
-    grad_op->SetInput("{{attr_name | to_pascal_case}}TensorList", this->Input("{{attr_name | to_pascal_case}}TensorList"));
+{% if 'tensor_name' in attr or 'manual_flag' not in attr %}
+    if (this->HasInput("{{attr | to_int_array_tensor_name}}")) {
+      grad_op->SetInput("{{attr | to_int_array_tensor_name}}", this->Input("{{attr | to_int_array_tensor_name}}"));
+    }
+{% endif %}
+{% if 'tensors_name' in attr or 'manual_flag' not in attr %}
+    if (this->HasInput("{{attr | to_int_array_tensors_name}}")) {
+      grad_op->SetInput("{{attr | to_int_array_tensors_name}}", this->Input("{{attr | to_int_array_tensors_name}}"));
+    }
+{% endif %}
 {% elif attr["typename"] == "Scalar" %}
-    grad_op->SetInput("{{attr_name | to_pascal_case}}Tensor", this->Input("{{attr_name | to_pascal_case}}Tensor"));
+    if (this->HasInput("{{attr | to_scalar_tensor_name}}")) {
+      grad_op->SetInput("{{attr | to_scalar_tensor_name}}", this->Input("{{attr | to_scalar_tensor_name}}"));
+    }
 {% endif %}
 {% else %}{# maybe something wrong: backward op has more attrs than the forward one#}
     grad_op->SetAttr("{{attr_name}}", {{process_default_value(attr)}});
...
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class MultinomialOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "A tensor contains probabilities of categories");
AddOutput("Out", "The output tensor of multinomial op");
AddAttr<int>("num_samples", "number of the generated samples")
.SetDefault(1)
.SupportTensor();
AddAttr<bool>("replacement", "can a category be sampled more than once")
.SetDefault(false);
AddComment(R"DOC(
This OP returns a Tensor filled with the sampled categories according to Multinomial probabilities.
Out ~ Multinomial(X)
)DOC");
}
};
class MultinomialOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(multinomial,
MultinomialInferShapeFunctor,
PD_INFER_META(phi::MultinomialInferMeta));
REGISTER_OPERATOR(
multinomial,
ops::MultinomialOp,
ops::MultinomialOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
MultinomialInferShapeFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class TopkV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
phi::DataLayout layout_ = phi::DataLayout::kAnyLayout;
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context(),
layout_,
library_);
}
};
class TopkV2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) The input of Topk op");
AddInput("K",
"(Tensor) Number of top elements to look for along "
"the last dimension (along each row for matrices).")
.AsDispensable();
AddOutput("Out", "(Tensor) The output tensor of Topk op");
AddOutput("Indices", "(Tensor) The indices of Topk elements of input");
AddComment(R"DOC(
Top K operator
If the input is a vector (1d tensor), this operator finds the k largest
entries in the vector and outputs their values and indices as vectors.
Thus values[j] is the j-th largest entry in input, and its index is indices[j].
For matrices, this operator computes the top k entries in each row. )DOC");
AddAttr<int>("k",
"(int, default 1) Number of top elements to look for along "
"the tensor).")
.SetDefault(1);
AddAttr<int>("axis",
"the axis to sort and get the k indices, value."
"if not set, will get k value in last axis.")
.SetDefault(-1);
AddAttr<bool>("largest",
"control flag whether to return largest or smallest")
.SetDefault(true);
AddAttr<bool>("sorted",
"control flag whether to return elements in sorted order")
.SetDefault(true);
}
};
class TopkV2OpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"),
true,
platform::errors::InvalidArgument("Input(X) should be not null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Indices"),
true,
platform::errors::InvalidArgument("Input(Indices) should be not null"));
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")),
true,
platform::errors::InvalidArgument(
"Grad Input(Out) should be not null"));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")),
true,
platform::errors::InvalidArgument("Grad Output(X) should be not null"));
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
template <typename T>
class TopkV2GradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("top_k_v2_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("X", this->Input("X"));
op->SetInput("Indices", this->Output("Indices"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(top_k_v2,
TopKInferShapeFunctor,
PD_INFER_META(phi::TopKInferMeta));
REGISTER_OPERATOR(top_k_v2,
ops::TopkV2Op,
ops::TopkV2OpMaker,
ops::TopkV2GradOpMaker<paddle::framework::OpDesc>,
ops::TopkV2GradOpMaker<paddle::imperative::OpBase>,
TopKInferShapeFunctor);
REGISTER_OPERATOR(top_k_v2_grad, ops::TopkV2OpGrad);
@@ -226,6 +226,16 @@
     func : cosh_grad
   inplace : (out_grad -> x_grad)
 
+- backward_op : crop_grad
+  forward : crop (Tensor x, IntArray shape, IntArray offsets) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad, IntArray offsets)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : CropGradInferMeta
+  kernel :
+    func : crop_grad
+    data_type : x
+
 - backward_op : cross_grad
   forward : cross (Tensor x, Tensor y, int axis = 9) -> Tensor(out)
   args : (Tensor x, Tensor y, Tensor out_grad, int axis)
@@ -1121,6 +1131,17 @@
     func : thresholded_relu_grad
   inplace : (out_grad -> x_grad)
 
+- backward_op : topk_grad
+  forward : topk (Tensor x, Scalar k, int axis = -1, bool largest = true, bool sorted = true) -> Tensor(out), Tensor(indices)
+  args : (Tensor x, Tensor indices, Tensor out_grad, Scalar k, int axis, bool largest, bool sorted)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : topk_grad
+    data_type : out_grad
+
 - backward_op : trace_grad
   forward : trace (Tensor x, int offset, int axis1, int axis2) -> Tensor(out)
   args : (Tensor x, Tensor out_grad, int offset, int axis1, int axis2)
...
@@ -325,16 +325,6 @@
   kernel :
     func : conv3d_transpose_grad
 
-- backward_op : crop_grad
-  forward : crop (Tensor x, IntArray shape, IntArray offsets) -> Tensor(out)
-  args : (Tensor x, Tensor out_grad, IntArray offsets)
-  output : Tensor(x_grad)
-  infer_meta :
-    func : CropGradInferMeta
-  kernel :
-    func : crop_grad
-    data_type : x
-
 - backward_op : cross_entropy_with_softmax_grad
   forward : cross_entropy_with_softmax (Tensor input, Tensor label, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis) -> Tensor(softmax), Tensor(loss)
   args : (Tensor label, Tensor softmax, Tensor loss_grad, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis)
@@ -1655,16 +1645,6 @@
   no_need_buffer : x
   backward : tile_double_grad
 
-- backward_op : topk_grad
-  forward : topk (Tensor x, Scalar k, int axis = -1, bool largest = true, bool sorted = true) -> Tensor(out), Tensor(indices)
-  args : (Tensor x, Tensor indices, Tensor out_grad, Scalar k = -1, int axis = -1, bool largest = true, bool sorted = true)
-  output : Tensor(x_grad)
-  infer_meta :
-    func : UnchangedInferMeta
-    param : [x]
-  kernel :
-    func : topk_grad
-
 - backward_op : transpose_double_grad
   forward : transpose_grad (Tensor grad_out, int[] perm) -> Tensor(grad_x)
   args : (Tensor grad_x_grad, int[] perm)
...
@@ -287,7 +287,7 @@
   backward : bilinear_tensor_product_grad
 
 - op : bincount
-  args: (Tensor x, Tensor weights, Scalar minlength)
+  args: (Tensor x, Tensor weights, Scalar(int) minlength = 0)
   output: Tensor(out)
   infer_meta:
     func: BincountInferMeta
@@ -464,16 +464,6 @@
   output : Tensor(out)
   invoke : copy_to_impl(x, place, blocking)
 
-- op : crop
-  args : (Tensor x, IntArray shape, IntArray offsets)
-  output : Tensor(out)
-  infer_meta :
-    func : CropInferMeta
-  kernel :
-    func : crop
-    data_type : x
-  backward : crop_grad
-
 # Part of python API paddle.nn.functional.cross_entropy
 - op : cross_entropy_with_softmax
   args : (Tensor input, Tensor label, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis)
@@ -1358,14 +1348,6 @@
     func : multiclass_nms3
   optional : rois_num
 
-- op : multinomial
-  args : (Tensor x, Scalar num_samples, bool replacement)
-  output : Tensor(out)
-  infer_meta :
-    func : MultinomialInferMeta
-  kernel :
-    func : multinomial
-
 - op : multiplex
   args : (Tensor[] inputs, Tensor index)
   output : Tensor
@@ -2045,15 +2027,6 @@
     func : tile
   backward : tile_grad
 
-- op : topk
-  args : (Tensor x, Scalar k, int axis = -1, bool largest = true, bool sorted = true)
-  output : Tensor(out), Tensor(indices)
-  infer_meta :
-    func : TopKInferMeta
-  kernel :
-    func : topk
-  backward : topk_grad
-
 - op : transpose
   args : (Tensor x, int[] perm)
   output : Tensor
...
@@ -252,6 +252,22 @@
   extra :
     attrs : [bool use_mkldnn = false, bool use_cudnn = false]
 
+- op : crop (crop_tensor)
+  backward : crop_grad (crop_tensor_grad)
+  inputs :
+    x : X
+  outputs :
+    out : Out
+  int_array:
+    shape :
+      data_type : int
+      tensor_name : Shape
+      tensors_name : ShapeTensor
+    offsets :
+      data_type : int
+      tensor_name : Offsets
+      tensors_name : OffsetsTensor
+
 - op : cross
   inputs :
     {x : X, y : Y}
@@ -823,6 +839,16 @@
   outputs :
     {out : Out, indices : Indices}
 
+- op : multinomial
+  inputs :
+    {x : X}
+  outputs :
+    out : Out
+  scalar :
+    num_samples :
+      data_type : int
+      support_tensor : true
+
 - op : multiply (elementwise_mul)
   backward : multiply_grad (elementwise_mul_grad)
   extra :
@@ -1193,6 +1219,17 @@
   outputs :
     out : Out
 
+- op : topk (top_k_v2)
+  backward : topk_grad (top_k_v2_grad)
+  inputs :
+    x : X
+  outputs :
+    {out : Out, indices : Indices}
+  scalar :
+    k :
+      data_type : int
+      tensor_name : K
+
 - op : trace
   inputs :
     x : Input
...
@@ -179,6 +179,16 @@
     func : cosh
   backward : cosh_grad
 
+- op : crop
+  args : (Tensor x, IntArray shape = {}, IntArray offsets = {})
+  output : Tensor(out)
+  infer_meta :
+    func : CropInferMeta
+  kernel :
+    func : crop
+    data_type : x
+  backward : crop_grad
+
 - op : cross
   args : (Tensor x, Tensor y, int axis = 9)
   output : Tensor
@@ -684,6 +694,15 @@
     func : mode
   backward : mode_grad
 
+- op : multinomial
+  args : (Tensor x, Scalar(int) num_samples = 1, bool replacement = false)
+  output : Tensor(out)
+  infer_meta :
+    func : MultinomialInferMeta
+  kernel :
+    func : multinomial
+    data_type : x
+
 - op : mv
   args : (Tensor x, Tensor vec)
   output : Tensor
@@ -926,6 +945,16 @@
     func : thresholded_relu
   backward : thresholded_relu_grad
 
+- op : topk
+  args : (Tensor x, Scalar(int) k = 1, int axis = -1, bool largest = true, bool sorted = true)
+  output : Tensor(out), Tensor(indices)
+  infer_meta :
+    func : TopKInferMeta
+  kernel :
+    func : topk
+    data_type : x
+  backward : topk_grad
+
 - op : trace
   args : (Tensor x, int offset = 0, int axis1 = 0, int axis2 = 1)
   output : Tensor
...
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature CropTensorOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.InputSize("ShapeTensor") > 0) {
if (ctx.InputSize("OffsetsTensor") > 0) {
return KernelSignature(
"crop", {"X"}, {"ShapeTensor", "OffsetsTensor"}, {"Out"});
} else if (ctx.HasInput("Offsets")) {
return KernelSignature(
"crop", {"X"}, {"ShapeTensor", "Offsets"}, {"Out"});
} else {
return KernelSignature(
"crop", {"X"}, {"ShapeTensor", "offsets"}, {"Out"});
}
} else if (ctx.HasInput("Shape")) {
if (ctx.InputSize("OffsetsTensor") > 0) {
return KernelSignature(
"crop", {"X"}, {"Shape", "OffsetsTensor"}, {"Out"});
} else if (ctx.HasInput("Offsets")) {
return KernelSignature("crop", {"X"}, {"Shape", "Offsets"}, {"Out"});
} else {
return KernelSignature("crop", {"X"}, {"Shape", "offsets"}, {"Out"});
}
} else {
if (ctx.InputSize("OffsetsTensor") > 0) {
return KernelSignature(
"crop", {"X"}, {"shape", "OffsetsTensor"}, {"Out"});
} else if (ctx.HasInput("Offsets")) {
return KernelSignature("crop", {"X"}, {"shape", "Offsets"}, {"Out"});
} else {
return KernelSignature("crop", {"X"}, {"shape", "offsets"}, {"Out"});
}
}
}
KernelSignature CropTensorGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.InputSize("OffsetsTensor") > 0) {
return KernelSignature(
"crop_grad", {"X", "Out@GRAD"}, {"OffsetsTensor"}, {"X@GRAD"});
} else if (ctx.HasInput("Offsets")) {
return KernelSignature(
"crop_grad", {"X", "Out@GRAD"}, {"Offsets"}, {"X@GRAD"});
} else {
return KernelSignature(
"crop_grad", {"X", "Out@GRAD"}, {"offsets"}, {"X@GRAD"});
}
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(crop_tensor, crop);
PD_REGISTER_BASE_KERNEL_NAME(crop_tensor_grad, crop_grad);
PD_REGISTER_ARG_MAPPING_FN(crop_tensor, phi::CropTensorOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(crop_tensor_grad,
phi::CropTensorGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature TopkOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("K")) {
return KernelSignature(
"topk", {"X"}, {"K", "axis", "largest", "sorted"}, {"Out", "Indices"});
} else {
return KernelSignature(
"topk", {"X"}, {"k", "axis", "largest", "sorted"}, {"Out", "Indices"});
}
}
KernelSignature TopkGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("topk_grad",
{"X", "Indices", "Out@GRAD"},
{"k", "axis", "largest", "sorted"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(top_k_v2, topk);
PD_REGISTER_BASE_KERNEL_NAME(top_k_v2_grad, topk_grad);
PD_REGISTER_ARG_MAPPING_FN(top_k_v2, phi::TopkOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(top_k_v2_grad, phi::TopkGradOpArgumentMapping);