提交 16abb325 编写于 作者: C cxxly 提交者: Xiaoxu Chen

move split/reshape prim api to auto generated file

上级 cdc5896f
......@@ -96,6 +96,7 @@ if(NOT DEFINED CBLAS_PROVIDER)
STATUS
"Found OpenBLAS (include: ${OPENBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})"
)
message(
STATUS "Found lapack in OpenBLAS (include: ${OPENBLAS_LAPACKE_INC_DIR})"
)
......
......@@ -43,3 +43,5 @@
- sin
- cos
- where
- reshape
- split
......@@ -225,14 +225,20 @@ class BaseAPI:
return inputs, attrs
def parse_output(self, outputs_list):
output_types_map = {
'Tensor[]': 'std::vector<Tensor>',
}
out_type_list = []
out_name_list = []
out_size_expr_list = []
for output_dict in outputs_list:
if output_dict['intermediate']:
continue
out_type_list.append(output_dict['typename'])
out_type_list.append(
output_types_map.get(
output_dict['typename'], output_dict['typename']
)
)
out_name_list.append(output_dict['name'])
if 'size' in output_dict.keys():
out_size_expr_list.append(output_dict['size'])
......
......@@ -115,7 +115,7 @@ op->SetInput("{{input.fluid_name | to_pascal}}", {std::static_pointer_cast<prim:
{%- if output.typename is tensor_sequence -%} {#- render the output of type std::Vector<Tensor> -#}
std::vector<Tensor> {{output.name}};
std::vector<std::string> {{output.name}}_names;
for (auto i=0; i<{{output.size}}; i++) {
for (size_t i=0; i<{{output.size}}; i++) {
auto tmp = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
{{output.name}}.push_back(tmp);
{{output.name}}_names.push_back(std::static_pointer_cast<prim::DescTensor>(tmp.impl())->Name());
......
......@@ -19,12 +19,6 @@
namespace paddle {
namespace prim {
template <>
Tensor reshape<Tensor>(const Tensor& x, const IntArray& shape) {
VLOG(4) << "Eager Prim API reshape_ad_func call";
return ::reshape_ad_func(x, shape);
}
template <>
Tensor full<Tensor>(const IntArray& shape,
const Scalar& value,
......@@ -34,14 +28,6 @@ Tensor full<Tensor>(const IntArray& shape,
return ::full_ad_func(shape, value, dtype, place);
}
template <>
std::vector<Tensor> split<Tensor>(const Tensor& x,
const IntArray& sections,
const Scalar& axis) {
VLOG(4) << "Eager Prim API split_ad_func call";
return ::split_ad_func(x, sections, axis);
}
template <>
Tensor cast<Tensor>(const Tensor& x, DataType dtype) {
return ::cast_ad_func(x, dtype);
......
......@@ -29,20 +29,12 @@ using Scalar = paddle::experimental::Scalar;
using IntArray = paddle::experimental::IntArray;
using DataType = phi::DataType;
template <typename T>
Tensor reshape(const Tensor& x, const IntArray& shape);
template <typename T>
Tensor full(const IntArray& shape,
const Scalar& value,
DataType dtype = DataType::FLOAT32,
const Place& place = CPUPlace());
template <typename T>
std::vector<Tensor> split(const Tensor& x,
const IntArray& sections,
const Scalar& axis);
template <typename T>
Tensor cast(const Tensor& x, DataType dtype);
......
......@@ -37,24 +37,6 @@
namespace paddle {
namespace prim {
template <>
Tensor reshape<DescTensor>(const Tensor& x, const IntArray& shape) {
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
// TODO(cxxly): move to auto generate dir.
op->SetType("reshape2");
op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
auto out = empty<DescTensor>({}, x.dtype(), paddle::Place());
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->SetAttr("shape", unsafe_vector_cast<int64_t, int>(shape.GetData()));
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
return out;
}
template <>
Tensor full<DescTensor>(const IntArray& shape,
const Scalar& value,
......@@ -127,32 +109,6 @@ Tensor full<DescTensor>(const IntArray& shape,
return out;
}
template <>
std::vector<Tensor> split<DescTensor>(const Tensor& x,
const IntArray& sections,
const Scalar& axis) {
int elem_num = sections.size();
std::vector<std::string> outs_name;
std::vector<Tensor> outs;
for (int i = 0; i < elem_num; ++i) {
Tensor out = empty<DescTensor>({}, x.dtype(), paddle::Place());
std::string out_name =
std::static_pointer_cast<prim::DescTensor>(out.impl())->Name();
outs_name.push_back(std::move(out_name));
outs.push_back(out);
}
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("split");
op->SetAttr("sections", sections.GetData());
op->SetAttr("axis", axis.to<int>());
op->SetOutput("Out", outs_name);
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
return outs;
}
template <>
Tensor cast<DescTensor>(const Tensor& x, DataType dtype) {
Tensor out = empty<DescTensor>({}, DataType::FLOAT32, paddle::Place());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册