提交 8f907058 编写于 作者: J Jiabin Yang 提交者: GitHub

Revert "Prim api gen (#49654)"

This reverts commit 813e27c9.
上级 f71f77e9
generated/prim_api/eager_prim_api.cc
generated/prim_api/tmp_eager_prim_api.cc
generated/prim_api/*.h
add_subdirectory(auto_code_generated)
add_subdirectory(manual) add_subdirectory(manual)
add_subdirectory(generated)
if(NOT (NOT WITH_PYTHON AND ON_INFER)) if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_library( cc_library(
prim_api prim_api
......
...@@ -13,6 +13,6 @@ ...@@ -13,6 +13,6 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" #include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h" #include "paddle/fluid/prim/api/manual/utils/utils.h"
set(api_yaml_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/ops.parsed.yaml"
)
set(legacy_api_yaml_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml"
)
set(tmp_eager_prim_api_cc_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/tmp_eager_prim_api.cc"
)
set(tmp_prim_api_h_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/tmp_prim_api.h"
)
set(eager_prim_api_cc_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/eager_prim_api.cc"
)
set(prim_api_h_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/prim_api.h")
set(prim_api_gen_file
${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/prim_gen.py)
message("prim api Code gen")
execute_process(
WORKING_DIRECTORY
${CMAKE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated
COMMAND
${PYTHON_EXECUTABLE} ${prim_api_gen_file} --api_yaml_path
${legacy_api_yaml_path} ${api_yaml_path} --prim_api_header_path
${tmp_prim_api_h_path} --eager_prim_api_source_path
${tmp_eager_prim_api_cc_path}
RESULT_VARIABLE _result)
if(${_result})
message(FATAL_ERROR "prim api genrate failed, exiting.")
endif()
execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different
${tmp_prim_api_h_path} ${prim_api_h_path})
execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different
${tmp_eager_prim_api_cc_path} ${eager_prim_api_cc_path})
message("copy tmp_xxx_prim_api to xxx_prim_api")
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# prim api list
white_ops_list = [
"pow",
"scale",
"multiply",
"unsqueeze",
"expand",
"full",
"reshape",
"divide",
"sum",
"exp",
]
inplace_out_type_map = {
"Tensor": "Tensor&",
"std::vector<Tensor>": "std::vector<Tensor>&",
}
inplace_optional_out_type_map = {
"Tensor": "paddle::optional<Tensor>&",
"std::vector<Tensor>": "paddle::optional<std::vector<Tensor>>&",
}
class BaseAPI:
def __init__(self, api_item_yaml):
# self.api = api_item_yaml['op']
self.api = api_item_yaml['name']
self.is_prim_api = False
if api_item_yaml['name'] in white_ops_list:
self.is_prim_api = True
#######################################
# inputs:
# names : [], list of input names
# input_info : {input_name : type}
# attrs:
# names : [], list of attribute names
# attr_info : { attr_name : (type, default_values)}
# outputs:
# names : [], list of output names
# types : [], list of output types
# out_size_expr : [], expression for getting size of vector<Tensor>
########################################
if self.is_prim_api:
(
self.inputs,
self.attrs,
self.outputs,
self.optional_vars,
) = self.parse_args(self.api, api_item_yaml)
self.inplace_map = api_item_yaml['inplace']
def get_api_func_name(self):
return self.api
# def is_inplace(self):
# if self.inplace_map
# return True
# return False
def get_input_tensor_args(self, inplace_flag=False):
input_args = []
inplace_type_map = {
"const Tensor&": "Tensor&",
"const paddle::optional<Tensor>&": "paddle::optional<Tensor>&",
"const std::vector<Tensor>&": "std::vector<Tensor>&",
"const paddle::optional<std::vector<Tensor>>&": "paddle::optional<std::vector<Tensor>>&",
}
for name in self.inputs['names']:
name = name.split('@')[0]
if inplace_flag and name in self.inplace_map.values():
input_args.append(
inplace_type_map[self.inputs['input_info'][name]]
+ ' '
+ name
)
else:
input_args.append(self.inputs['input_info'][name] + ' ' + name)
return input_args
def get_declare_args(self, inplace_flag=False):
declare_args = self.get_input_tensor_args(inplace_flag)
for name in self.attrs['names']:
default_value = ''
if self.attrs['attr_info'][name][1] is not None:
default_value = ' = ' + self.attrs['attr_info'][name][1]
declare_args.append(
self.attrs['attr_info'][name][0] + ' ' + name + default_value
)
return ", ".join(declare_args)
def get_declare_args_nodefault(self, inplace_flag=False):
declare_args = self.get_input_tensor_args(inplace_flag)
for name in self.attrs['names']:
declare_args.append(self.attrs['attr_info'][name][0] + ' ' + name)
return ", ".join(declare_args)
def get_return_type(self, inplace_flag=False):
out_type_list = []
for i, out_type in enumerate(self.outputs['types']):
out_name = self.outputs['names'][i].split('@')[0]
if inplace_flag and out_name in self.inplace_map:
if self.inplace_map[out_name] in self.optional_vars:
out_type_list.append(
inplace_optional_out_type_map[out_type]
)
else:
out_type_list.append(inplace_out_type_map[out_type])
else:
out_type_list.append(out_type)
if len(out_type_list) == 1:
return out_type_list[0]
else:
return "std::tuple<" + ", ".join(out_type_list) + ">"
def parse_args(self, api_name, api_item_yaml):
optional_vars = []
for input_dict in api_item_yaml['inputs']:
if input_dict['optional']:
optional_vars.append(input_dict['name'])
inputs, attrs = self.parse_input_and_attr(
api_item_yaml['inputs'], api_item_yaml['attrs']
)
output_type_list, output_names, out_size_expr = self.parse_output(
api_item_yaml['outputs']
)
return (
inputs,
attrs,
{
'names': output_names,
'types': output_type_list,
'out_size_expr': out_size_expr,
},
optional_vars,
)
def parse_input_and_attr(self, inputs_list, attrs_list):
input_types_map = {
'Tensor': 'const Tensor&',
'Tensor[]': 'const std::vector<Tensor>&',
}
attr_types_map = {
'IntArray': 'const IntArray&',
'Scalar': 'const Scalar&',
'Scalar(int)': 'const Scalar&',
'Scalar(int64_t)': 'const Scalar&',
'Scalar(float)': 'const Scalar&',
'Scalar(dobule)': 'const Scalar&',
'Scalar[]': 'const std::vector<phi::Scalar>&',
'int': 'int',
'int32_t': 'int32_t',
'int64_t': 'int64_t',
'long': 'long',
'size_t': 'size_t',
'float': 'float',
'float[]': 'const std::vector<float>&',
'double': 'double',
'bool': 'bool',
'bool[]': 'const std::vector<bool>&',
'str': 'const std::string&',
'str[]': 'const std::vector<std::string>&',
'Place': 'const Place&',
'DataLayout': 'DataLayout',
'DataType': 'DataType',
'int64_t[]': 'const std::vector<int64_t>&',
'int[]': 'const std::vector<int>&',
}
optional_types_trans = {
'Tensor': 'const paddle::optional<Tensor>&',
'Tensor[]': 'const paddle::optional<std::vector<Tensor>>&',
'int': 'paddle::optional<int>',
'int32_t': 'paddle::optional<int32_t>',
'int64_t': 'paddle::optional<int64_t>',
'float': 'paddle::optional<float>',
'double': 'paddle::optional<double>',
'bool': 'paddle::optional<bool>',
'Place': 'paddle::optional<const Place&>',
'DataLayout': 'paddle::optional<DataLayout>',
'DataType': 'paddle::optional<DataType>',
}
inputs = {'names': [], 'input_info': {}}
for input_dict in inputs_list:
inputs['names'].append(input_dict['name'])
if input_dict['optional']:
inputs['input_info'][input_dict['name']] = optional_types_trans[
input_dict['typename']
]
else:
inputs['input_info'][input_dict['name']] = input_types_map[
input_dict['typename']
]
attrs = {'names': [], 'attr_info': {}}
for attr_dict in attrs_list:
attrs['names'].append(attr_dict['name'])
if 'default_value' in attr_dict.keys():
default_value = attr_dict['default_value']
else:
default_value = None
if 'optional' in attr_dict.keys():
attrs['attr_info'][attr_dict['name']] = (
optional_types_trans[attr_dict['typename']],
default_value,
)
else:
attrs['attr_info'][attr_dict['name']] = (
attr_types_map[attr_dict['typename']],
default_value,
)
return inputs, attrs
def parse_output(self, outputs_list):
out_type_list = []
out_name_list = []
out_size_expr_list = []
for output_dict in outputs_list:
if output_dict['intermediate']:
continue
out_type_list.append(output_dict['typename'])
out_name_list.append(output_dict['name'])
if 'size' in output_dict.keys():
out_size_expr_list.append(output_dict['size'])
else:
out_size_expr_list.append(None)
return out_type_list, out_name_list, out_size_expr_list
class EagerPrimAPI(BaseAPI):
def __init__(self, api_item_yaml):
super().__init__(api_item_yaml)
def get_api__func_name(self):
api_func_name = self.api
# if self.is_inplace:
# if api_func_name[-1] != '_':
# api_func_name += '_'
# print("after api name", api_func_name)
return api_func_name
def gene_prim_api_declaration(self):
api_declaration = ""
api_func_name = self.get_api__func_name()
if api_func_name[-1] != '_':
api_declaration = f"""
template <typename T>
{self.get_return_type()} {api_func_name}({self.get_declare_args()});
"""
else:
api_declaration = (
api_declaration
+ f"""
template <typename T>
{self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_declare_args(inplace_flag=True)});
"""
)
return api_declaration
def get_ad_func_input_args(self, inplace_flag=False):
input_args = []
for name in self.inputs['names']:
name = name.split('@')[0]
if inplace_flag and name in self.inplace_map.values():
input_args.append(name)
else:
input_args.append(name)
return input_args
def get_ad_func_args(self, inplace_flag=False):
ad_func_args = self.get_ad_func_input_args(inplace_flag)
for name in self.attrs['names']:
default_value = ''
if self.attrs['attr_info'][name][1] is not None:
default_value = ' = ' + self.attrs['attr_info'][name][1]
ad_func_args.append(name)
ad_func_args_str = ", ".join(ad_func_args)
return ad_func_args_str
def gene_ad_func_call(self):
api_func_name = self.get_api__func_name()
dygraph_ad_func_name = '::' + api_func_name + '_ad_func'
dygraph_ad_func_parameters = self.get_ad_func_args()
ad_func_call_str = f"""
VLOG(4) << "Eager Prim API {api_func_name}_ad_func call";
return {dygraph_ad_func_name}({dygraph_ad_func_parameters});
"""
# print("ad_func_call_str: ", ad_func_call_str)
return ad_func_call_str
def gene_eager_prim_api_code(self):
api_code = ""
indent = " "
api_func_name = self.get_api__func_name()
template = '<Tensor>'
# func decalaration
if api_func_name[-1] != '_':
api_code = f"""
template <>
{self.get_return_type()} {api_func_name}{template}({self.get_declare_args_nodefault()})
"""
else:
api_code = f"""
template <>
{self.get_return_type(inplace_flag=True)} {api_func_name}{template}({self.get_declare_args_nodefault(inplace_flag=True)})
"""
# func code
api_code = api_code + '{'
api_code += f"""{self.gene_ad_func_call()}"""
api_code += '}' + '\n'
return api_code
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import yaml
from prim_base import EagerPrimAPI
def header_include():
return """
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/place.h"
#include "paddle/utils/optional.h"
"""
def eager_source_include(header_file_path):
return """
#include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h"
"""
def api_namespace():
return (
"""
namespace paddle {
namespace prim {
""",
"""
using Tensor = paddle::experimental::Tensor;
using Scalar = paddle::experimental::Scalar;
using IntArray = paddle::experimental::IntArray;
using DataType = paddle::experimental::DataType;
""",
"""
} // namespace prim
} // namespace paddle
""",
)
def generate_api(api_yaml_path, header_file_path, eager_prim_source_file_path):
apis = []
for each_api_yaml in api_yaml_path:
with open(each_api_yaml, 'r') as f:
api_list = yaml.load(f, Loader=yaml.FullLoader)
if api_list:
apis.extend(api_list)
header_file = open(header_file_path, 'w')
eager_prim_source_file = open(eager_prim_source_file_path, 'w')
namespace = api_namespace()
header_file.write("#pragma once\n")
header_file.write(header_include())
header_file.write(namespace[0])
header_file.write(namespace[1])
include_header_file = (
"#include paddle/fluid/prim/api/generated/prim_api/prim_api.h"
)
eager_prim_source_file.write(eager_source_include(include_header_file))
eager_prim_source_file.write(namespace[0])
for api in apis:
prim_api = EagerPrimAPI(api)
if prim_api.is_prim_api:
header_file.write(prim_api.gene_prim_api_declaration())
eager_prim_source_file.write(prim_api.gene_eager_prim_api_code())
header_file.write(namespace[2])
eager_prim_source_file.write(namespace[2])
header_file.close()
eager_prim_source_file.close()
def main():
parser = argparse.ArgumentParser(
description='Generate PaddlePaddle C++ API files'
)
parser.add_argument(
'--api_yaml_path',
help='path to api yaml file',
nargs='+',
default=['paddle/phi/api/yaml/ops.yaml'],
)
parser.add_argument(
'--prim_api_header_path',
help='output of generated prim_api header code file',
default='paddle/fluid/prim/api/generated/prim_api/prim_api.h',
)
parser.add_argument(
'--eager_prim_api_source_path',
help='output of generated eager_prim_api source code file',
default='paddle/fluid/prim/api/generated/prim_api/eager_prim_api.cc',
)
options = parser.parse_args()
api_yaml_path = options.api_yaml_path
prim_api_header_file_path = options.prim_api_header_path
eager_prim_api_source_file_path = options.eager_prim_api_source_path
generate_api(
api_yaml_path,
prim_api_header_file_path,
eager_prim_api_source_file_path,
)
if __name__ == '__main__':
main()
add_subdirectory(prim_api)
add_subdirectory(utils) add_subdirectory(utils)
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h"
#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h" #include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h" #include "paddle/fluid/prim/api/manual/utils/utils.h"
#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/int_array.h"
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
#include "paddle/phi/capi/include/wrapper_base.h"
namespace paddle {
namespace prim {
template <>
Tensor pow<Tensor>(const Tensor& x, const paddle::experimental::Scalar& y) {
return ::pow_ad_func(x, y);
}
template <>
Tensor scale<Tensor>(const Tensor& x,
const paddle::experimental::Scalar& scale,
float bias,
bool bias_after_scale) {
return ::scale_ad_func(x, scale, bias, bias_after_scale);
}
template <>
Tensor multiply<Tensor>(const Tensor& x, const Tensor& y) {
return ::multiply_ad_func(x, y);
}
template <>
Tensor expand<Tensor>(const Tensor& x, const IntArray& shape) {
return ::expand_ad_func(x, shape);
}
template <>
Tensor unsqueeze<Tensor>(const Tensor& x, const IntArray& axis) {
return ::unsqueeze_ad_func(x, axis);
}
template <>
Tensor divide<Tensor>(const Tensor& x, const Tensor& y) {
return ::divide_ad_func(x, y);
}
template <>
Tensor full<Tensor>(paddle::experimental::IntArray shape,
paddle::experimental::Scalar value,
paddle::experimental::DataType dtype,
paddle::platform::Place place) {
return ::full_ad_func(shape, value, dtype, place);
}
template <>
Tensor sum<Tensor>(Tensor x, IntArray axis, DataType dtype, bool keepdim) {
return ::sum_ad_func(x, axis, dtype, keepdim);
}
template <>
Tensor reshape<Tensor>(Tensor x, IntArray shape) {
return ::reshape_ad_func(x, shape);
}
template <>
Tensor exp<Tensor>(const Tensor& x) {
return ::exp_ad_func(x);
}
template <typename T>
Tensor expand(const Tensor& x, const IntArray& shape) {
return ::expand_ad_func(x, shape);
}
} // namespace prim
} // namespace paddle
...@@ -12,15 +12,56 @@ ...@@ -12,15 +12,56 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// prim api which can't be generated
#pragma once #pragma once
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
#include "paddle/utils/optional.h" #include "paddle/utils/optional.h"
namespace paddle { namespace paddle {
namespace prim {} // namespace prim namespace prim {
using Tensor = paddle::experimental::Tensor;
using IntArray = paddle::experimental::IntArray;
using Scalar = paddle::experimental::Scalar;
template <typename T>
Tensor pow(const Tensor& x, const Scalar& y);
template <typename T>
Tensor scale(const Tensor& X,
const Scalar& scale,
float bias,
bool bias_after_scale);
template <typename T>
Tensor multiply(const Tensor& x, const Tensor& y);
template <typename T>
Tensor expand(const Tensor& x, const IntArray& shape);
template <typename T>
Tensor unsqueeze(const Tensor& x, const IntArray& axis);
template <typename T>
Tensor divide(const Tensor& x, const Tensor& y);
template <typename T>
Tensor full(IntArray shape,
Scalar value,
DataType dtype = DataType::FLOAT32,
Place place = CPUPlace());
template <typename T>
Tensor sum(Tensor x,
IntArray axis = {},
DataType dtype = DataType::UNDEFINED,
bool keepdim = false);
template <typename T>
Tensor reshape(Tensor x, IntArray shape);
template <typename T>
Tensor expand(const Tensor& x, const IntArray& shape);
template <typename T>
Tensor exp(const Tensor& x);
} // namespace prim
} // namespace paddle } // namespace paddle
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h"
#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h" #include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h" #include "paddle/fluid/prim/api/manual/utils/utils.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
...@@ -38,7 +37,7 @@ namespace paddle { ...@@ -38,7 +37,7 @@ namespace paddle {
namespace prim { namespace prim {
template <> template <>
Tensor pow<DescTensor>(const Tensor& x, const Scalar& y) { Tensor pow<DescTensor>(const Tensor& x, const paddle::experimental::Scalar& y) {
Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place()); Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock(); framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp(); framework::OpDesc* op = block->AppendOp();
...@@ -56,7 +55,7 @@ Tensor pow<DescTensor>(const Tensor& x, const Scalar& y) { ...@@ -56,7 +55,7 @@ Tensor pow<DescTensor>(const Tensor& x, const Scalar& y) {
template <> template <>
Tensor scale<DescTensor>(const Tensor& x, Tensor scale<DescTensor>(const Tensor& x,
const Scalar& scale, const paddle::experimental::Scalar& scale,
float bias, float bias,
bool bias_after_scale) { bool bias_after_scale) {
Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place()); Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
...@@ -96,63 +95,63 @@ Tensor multiply<DescTensor>(const Tensor& x, const Tensor& y) { ...@@ -96,63 +95,63 @@ Tensor multiply<DescTensor>(const Tensor& x, const Tensor& y) {
} }
template <> template <>
Tensor unsqueeze<DescTensor>(const Tensor& x, const IntArray& axis) { Tensor expand<DescTensor>(const Tensor& x, const IntArray& shape) {
Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place()); Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock(); framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp(); framework::OpDesc* op = block->AppendOp();
op->SetType("unsqueeze2"); op->SetType("expand_v2");
op->SetInput("X", op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()}); {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
op->SetOutput( op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()}); "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
std::vector<int> new_shape(axis.GetData().begin(), axis.GetData().end()); std::vector<int> new_shape(shape.GetData().begin(), shape.GetData().end());
op->SetAttr("axes", new_shape); op->SetAttr("shape", new_shape);
op->CheckAttrs(); op->CheckAttrs();
op->InferVarType(block); op->InferVarType(block);
return out; return out;
} }
template <> template <>
Tensor expand<DescTensor>(const Tensor& x, const IntArray& shape) { Tensor divide<DescTensor>(const Tensor& x, const Tensor& y) {
// Grad infershape
Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place()); Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock(); framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp(); framework::OpDesc* op = block->AppendOp();
op->SetType("expand_v2"); op->SetType("elementwise_div");
op->SetInput("X", op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()}); {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
op->SetInput("Y",
{std::static_pointer_cast<prim::DescTensor>(y.impl())->Name()});
op->SetOutput( op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()}); "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
std::vector<int> new_shape(shape.GetData().begin(), shape.GetData().end());
op->SetAttr("shape", new_shape);
op->CheckAttrs(); op->CheckAttrs();
op->InferVarType(block); op->InferVarType(block);
op->InferShape(*block);
return out; return out;
} }
template <> template <>
Tensor divide<DescTensor>(const Tensor& x, const Tensor& y) { Tensor unsqueeze<DescTensor>(const Tensor& x, const IntArray& axis) {
// Grad infershape
Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place()); Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock(); framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp(); framework::OpDesc* op = block->AppendOp();
op->SetType("elementwise_div"); op->SetType("unsqueeze2");
op->SetInput("X", op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()}); {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
op->SetInput("Y",
{std::static_pointer_cast<prim::DescTensor>(y.impl())->Name()});
op->SetOutput( op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()}); "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
std::vector<int> new_shape(axis.GetData().begin(), axis.GetData().end());
op->SetAttr("axes", new_shape);
op->CheckAttrs(); op->CheckAttrs();
op->InferVarType(block); op->InferVarType(block);
op->InferShape(*block);
return out; return out;
} }
template <> template <>
Tensor full<DescTensor>(const IntArray& shape, Tensor full<DescTensor>(paddle::experimental::IntArray shape,
const Scalar& value, paddle::experimental::Scalar value,
DataType dtype, paddle::experimental::DataType dtype,
const Place& place) { paddle::platform::Place place) {
// Grad infershape // Grad infershape
Tensor out = empty<DescTensor>({}, dtype, place); Tensor out = empty<DescTensor>({}, dtype, place);
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock(); framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
...@@ -160,8 +159,9 @@ Tensor full<DescTensor>(const IntArray& shape, ...@@ -160,8 +159,9 @@ Tensor full<DescTensor>(const IntArray& shape,
op->SetType("fill_constant"); op->SetType("fill_constant");
op->SetAttr("shape", shape.GetData()); op->SetAttr("shape", shape.GetData());
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
((dtype == DataType::FLOAT32) || (dtype == DataType::FLOAT64) || ((dtype == paddle::experimental::DataType::FLOAT32) ||
(dtype == DataType::FLOAT16)), (dtype == paddle::experimental::DataType::FLOAT64) ||
(dtype == paddle::experimental::DataType::FLOAT16)),
true, true,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"We only support float32/float16 for full, but we got data type: %s", "We only support float32/float16 for full, but we got data type: %s",
...@@ -177,9 +177,9 @@ Tensor full<DescTensor>(const IntArray& shape, ...@@ -177,9 +177,9 @@ Tensor full<DescTensor>(const IntArray& shape,
} }
template <> template <>
Tensor sum<DescTensor>(const Tensor& x, Tensor sum<DescTensor>(Tensor x,
const IntArray& axis, paddle::experimental::IntArray axis,
DataType dtype, paddle::experimental::DataType dtype,
bool keepdim) { bool keepdim) {
// Grad infershape // Grad infershape
Tensor out = empty<DescTensor>({}, dtype, paddle::Place()); Tensor out = empty<DescTensor>({}, dtype, paddle::Place());
...@@ -204,7 +204,7 @@ Tensor sum<DescTensor>(const Tensor& x, ...@@ -204,7 +204,7 @@ Tensor sum<DescTensor>(const Tensor& x,
} }
template <> template <>
Tensor reshape<DescTensor>(const Tensor& x, const IntArray& shape) { Tensor reshape<DescTensor>(Tensor x, paddle::experimental::IntArray shape) {
// Grad infershape // Grad infershape
Tensor out = empty<DescTensor>({}, x.dtype(), paddle::Place()); Tensor out = empty<DescTensor>({}, x.dtype(), paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock(); framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册