Unverified · Commit 98e96853 · Authored by YuanRisheng · Committed by GitHub

[PHI] Separate xshape kernel from normal kernel (#44315)

* separate xshape kernel from normal kernel

* fix bugs in infermeta

* fix compile bugs

* fix compile bugs
Parent 15dd94ab
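The same pattern recurs in every file below: a kernel that used to produce an auxiliary `xshape` output is split into a slim base kernel plus a `*_with_xshape` wrapper that simply delegates. A minimal standalone sketch of that shape, using hypothetical names (`BaseKernel`, `WithXShapeKernel`); the real kernels take a device context and DenseTensors:

```cpp
#include <cstdio>
#include <vector>

// Slim kernel: computes only the primary output.
void BaseKernel(const std::vector<int64_t>& x_dims,
                std::vector<int64_t>* out_dims) {
  *out_dims = x_dims;  // stand-in for the real shape computation
}

// Wrapper: same computation; the extra xshape parameter exists only so
// the op signature matches. Its dims are filled during shape inference,
// so the kernel body leaves it untouched, as in the diff below.
void WithXShapeKernel(const std::vector<int64_t>& x_dims,
                      std::vector<int64_t>* out_dims,
                      std::vector<int64_t>* xshape) {
  (void)xshape;
  BaseKernel(x_dims, out_dims);
}

int main() {
  std::vector<int64_t> out, xshape;
  WithXShapeKernel({2, 1, 3}, &out, &xshape);
  std::printf("out rank: %zu\n", out.size());  // 3
}
```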
......@@ -106,7 +106,7 @@ namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(einsum,
EinsumInferShapeFunctor,
PD_INFER_META(phi::EinsumInferMeta));
PD_INFER_META(phi::EinsumRawInferMeta));
REGISTER_OPERATOR(einsum,
ops::EinsumOp,
......
......@@ -347,7 +347,7 @@ namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(squeeze2,
SqueezeInferShapeFunctor,
PD_INFER_META(phi::SqueezeInferMeta));
PD_INFER_META(phi::SqueezeWithXShapeInferMeta));
REGISTER_OPERATOR(squeeze,
ops::SqueezeOp,
......
......@@ -347,7 +347,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnsqueezeGradOpNoNeedBufferVarInferer, "X");
DECLARE_INFER_SHAPE_FUNCTOR(unsqueeze2,
Unsqueeze2InferShapeFunctor,
PD_INFER_META(phi::UnsqueezeInferMeta));
PD_INFER_META(phi::UnsqueezeWithXShapeInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(unsqueeze,
......
......@@ -325,8 +325,8 @@ add_custom_command(
${dygraph_api_header_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dygraph_api_source_file_tmp}
${dygraph_api_source_file}
DEPENDS ${api_yaml_file} ${sparse_api_yaml_file} ${im_api_gen_file}
${api_gen_base} ${api_gen_file}
DEPENDS ${api_yaml_file} ${legacy_api_yaml_file} ${sparse_api_yaml_file}
${im_api_gen_file} ${api_gen_base} ${api_gen_file}
VERBATIM)
# generate wrapped infermeta
......
......@@ -582,10 +582,10 @@
args : (Tensor[] x, str equation)
output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()}
infer_meta :
func : EinsumInferMeta
func : EinsumRawInferMeta
param : [x, equation]
kernel :
func : einsum
func : einsum_raw
backward : einsum_grad
- api : elementwise_pow
......@@ -2047,9 +2047,9 @@
args : (Tensor x, int[] axes)
output : Tensor(out), Tensor(xshape)
infer_meta :
func : SqueezeInferMeta
func : SqueezeWithXShapeInferMeta
kernel :
func : squeeze
func : squeeze_with_xshape
view: (x -> out)
intermediate : xshape
backward : squeeze_grad
......@@ -2290,9 +2290,9 @@
args : (Tensor x, IntArray axis)
output : Tensor(out), Tensor(xshape)
infer_meta :
func : UnsqueezeInferMeta
func : UnsqueezeWithXShapeInferMeta
kernel :
func : unsqueeze
func : unsqueeze_with_xshape
view: (x -> out)
intermediate : xshape
backward : unsqueeze_grad
......
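In the yaml entries above, `intermediate : xshape` hides the second output from the Python API (it exists only to feed the backward op), and `view: (x -> out)` declares that `out` may alias `x`'s storage, since squeeze/unsqueeze only rewrite shape metadata. A toy, self-contained illustration of that aliasing idea, with a hypothetical `Tensor` struct rather than phi's:

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Toy tensor: shared storage plus independent dims metadata.
struct Tensor {
  std::shared_ptr<std::vector<float>> data;
  std::vector<int64_t> dims;
};

// view: (x -> out): the output shares x's buffer; only dims change.
Tensor Unsqueeze(const Tensor& x, int axis) {
  Tensor out{x.data, x.dims};
  out.dims.insert(out.dims.begin() + axis, 1);
  return out;
}

int main() {
  Tensor x{std::make_shared<std::vector<float>>(6, 1.f), {2, 3}};
  Tensor y = Unsqueeze(x, 0);
  assert(y.data == x.data);  // no copy, storage is shared
  assert((y.dims == std::vector<int64_t>{1, 2, 3}));
}
```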
......@@ -570,9 +570,7 @@ void EigvalsInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config) {
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape) {
MetaTensor* out) {
// collect the following information to prepare einsum.
LabelMap labelshape(0);
LabelMap labeltype(LabelType::Reduction);
......@@ -609,6 +607,14 @@ void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape);
out->set_dims(make_ddim(output_dims));
out->set_dtype(inputs[0]->dtype());
}
void EinsumRawInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape) {
EinsumInferMeta(inputs, equation, out);
for (size_t i = 0; i < xshape.size(); ++i) {
if (xshape[i] != nullptr) {
xshape[i]->set_dims(inputs[i]->dims());
......@@ -2448,8 +2454,7 @@ void SplitInferMeta(const MetaTensor& x,
void SqueezeInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
MetaTensor* out,
MetaTensor* xshape) {
MetaTensor* out) {
const auto& x_dims = x.dims();
// Check input tensor dims (< 6, Eigen's limit).
PADDLE_ENFORCE_LE(x_dims.size(),
......@@ -2469,15 +2474,25 @@ void SqueezeInferMeta(const MetaTensor& x,
out->share_lod(x);
}
out->set_dtype(x.dtype());
}
void SqueezeWithXShapeInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
MetaTensor* out,
MetaTensor* xshape) {
SqueezeInferMeta(x, axes, out);
const auto& x_dims = x.dims();
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x);
xshape->set_dtype(x.dtype());
out->set_dtype(x.dtype());
if (xshape) {
xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x);
xshape->set_dtype(x.dtype());
}
}
void StridedSliceRawInferMeta(const MetaTensor& x,
......@@ -3310,7 +3325,6 @@ void UniqueRawInferMeta(const MetaTensor& x,
void UnsqueezeInferMeta(const MetaTensor& x,
const IntArray& axes,
MetaTensor* out,
MetaTensor* xshape,
MetaConfig config) {
const auto& x_dims = x.dims();
// Validity Check: input tensor dims (<6).
......@@ -3339,14 +3353,22 @@ void UnsqueezeInferMeta(const MetaTensor& x,
}
out->set_dtype(x.dtype());
}
if (xshape) {
// set xshape dims.
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
}
void UnsqueezeWithXShapeInferMeta(const MetaTensor& x,
const IntArray& axes,
MetaTensor* out,
MetaTensor* xshape,
MetaConfig config) {
const auto& x_dims = x.dims();
UnsqueezeInferMeta(x, axes, out, config);
// set xshape dims.
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
if (xshape) {
xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x);
xshape->set_dtype(x.dtype());
......
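Both new infer metas use the same `xshape` layout: rank+1 entries, a leading 0, then the input dims; the grad op later strips the sentinel to recover the shape it must restore. A small round-trip check (the `RecoverShape` helper is hypothetical, shown only to make the convention concrete):

```cpp
#include <cassert>
#include <vector>

// Same construction as SqueezeWithXShapeInferMeta /
// UnsqueezeWithXShapeInferMeta above: [0, d0, d1, ...].
std::vector<int64_t> BuildXShape(const std::vector<int64_t>& x_dims) {
  std::vector<int64_t> xshape(x_dims.size() + 1);
  xshape[0] = 0;
  for (size_t i = 0; i < x_dims.size(); ++i) xshape[i + 1] = x_dims[i];
  return xshape;
}

// Hypothetical grad-side helper: drop the sentinel.
std::vector<int64_t> RecoverShape(const std::vector<int64_t>& xshape) {
  return {xshape.begin() + 1, xshape.end()};
}

int main() {
  const std::vector<int64_t> x_dims{2, 1, 3};
  assert(RecoverShape(BuildXShape(x_dims)) == x_dims);
}
```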
......@@ -97,9 +97,13 @@ void EigvalsInferMeta(const MetaTensor& x,
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape);
MetaTensor* out);
void EinsumRawInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape);
void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
......@@ -341,8 +345,12 @@ void SplitInferMeta(const MetaTensor& x_meta,
void SqueezeInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
MetaTensor* out,
MetaTensor* xshape);
MetaTensor* out);
void SqueezeWithXShapeInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
MetaTensor* out,
MetaTensor* xshape);
void StridedSliceRawInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
......@@ -470,9 +478,14 @@ void UniqueRawInferMeta(const MetaTensor& x,
void UnsqueezeInferMeta(const MetaTensor& x,
const IntArray& axes,
MetaTensor* out,
MetaTensor* xshape,
MetaConfig config = MetaConfig());
void UnsqueezeWithXShapeInferMeta(const MetaTensor& x,
const IntArray& axes,
MetaTensor* out,
MetaTensor* xshape,
MetaConfig config = MetaConfig());
void UnStackInferMeta(const MetaTensor& x,
int axis,
int num,
......
......@@ -18,7 +18,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"
PD_REGISTER_KERNEL(einsum,
PD_REGISTER_KERNEL(einsum_raw,
CPU,
ALL_LAYOUT,
phi::EinsumKernelRaw,
......@@ -26,3 +26,12 @@ PD_REGISTER_KERNEL(einsum,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(einsum,
CPU,
ALL_LAYOUT,
phi::EinsumKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -32,3 +32,18 @@ PD_REGISTER_KERNEL(squeeze,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(squeeze_with_xshape,
CPU,
ALL_LAYOUT,
phi::SqueezeWithXShapeKernel,
float,
double,
phi::dtype::bfloat16,
bool,
int,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -33,3 +33,19 @@ PD_REGISTER_KERNEL(unsqueeze,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(unsqueeze_with_xshape,
CPU,
ALL_LAYOUT,
phi::UnsqueezeWithXShapeKernel,
float,
double,
phi::dtype::bfloat16,
bool,
int,
int16_t,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -18,7 +18,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"
PD_REGISTER_KERNEL(einsum,
PD_REGISTER_KERNEL(einsum_raw,
GPU,
ALL_LAYOUT,
phi::EinsumKernelRaw,
......@@ -28,3 +28,14 @@ PD_REGISTER_KERNEL(einsum,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(einsum,
GPU,
ALL_LAYOUT,
phi::EinsumKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -33,3 +33,19 @@ PD_REGISTER_KERNEL(squeeze,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(squeeze_with_xshape,
GPU,
ALL_LAYOUT,
phi::SqueezeWithXShapeKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16,
bool,
int,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -34,3 +34,20 @@ PD_REGISTER_KERNEL(unsqueeze,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(unsqueeze_with_xshape,
GPU,
ALL_LAYOUT,
phi::UnsqueezeWithXShapeKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
bool,
int,
int16_t,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/expand_as_kernel.h"
#include "paddle/phi/kernels/funcs/matrix_solve.h"
......@@ -77,7 +79,7 @@ static std::vector<int64_t> get_broadcast_batch_portion(
static inline std::vector<int> convert_to_int_vec(std::vector<int64_t> a) {
std::vector<int> ret;
for (size_t i = 0; i < a.size(); i++) {
ret.emplace_back(int(a[i]));
ret.emplace_back(static_cast<int>(a[i]));
}
return ret;
......@@ -167,7 +169,7 @@ static void linalg_solve(const Context& dev_ctx,
out_tmp.Resize(out->dims());
out_tmp = *out;
phi::SqueezeKernel<T, Context>(dev_ctx, out_tmp, {-1}, out, nullptr);
phi::SqueezeKernel<T, Context>(dev_ctx, out_tmp, {-1}, out);
} else {
PADDLE_ENFORCE_EQ(
x_dim[x_dim_size - 1],
......
......@@ -22,8 +22,7 @@ template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
DenseTensor* out,
DenseTensor* xshape) {
DenseTensor* out) {
auto x_dims = x.dims();
auto out_dims = funcs::GetOutputSqueezeShape(axes, x_dims, true);
......@@ -31,4 +30,14 @@ void SqueezeKernel(const Context& dev_ctx,
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
out->Resize(out_dims);
}
template <typename T, typename Context>
void SqueezeWithXShapeKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
DenseTensor* out,
DenseTensor* xshape) {
SqueezeKernel<T, Context>(dev_ctx, x, axes, out);
}
} // namespace phi
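Note that `SqueezeWithXShapeKernel` above never touches `xshape`: by the time the kernel runs, `SqueezeWithXShapeInferMeta` has already set its dims. The squeezed output shape itself comes from `funcs::GetOutputSqueezeShape`; below is a simplified stand-in for it, under stated assumptions (negative axes wrap; with no axes given, every size-1 dim is dropped; a listed axis is dropped only when its extent is 1):

```cpp
#include <cstdio>
#include <set>
#include <vector>

// Simplified stand-in for funcs::GetOutputSqueezeShape; see the
// assumptions stated above.
std::vector<int64_t> SqueezeShape(const std::vector<int>& axes,
                                  const std::vector<int64_t>& dims) {
  const int rank = static_cast<int>(dims.size());
  std::set<int> drop;
  for (int a : axes) drop.insert(a < 0 ? a + rank : a);
  std::vector<int64_t> out;
  for (int i = 0; i < rank; ++i) {
    const bool squeezed =
        axes.empty() ? dims[i] == 1 : (drop.count(i) && dims[i] == 1);
    if (!squeezed) out.push_back(dims[i]);
  }
  return out;
}

int main() {
  for (int64_t d : SqueezeShape({-2}, {5, 1, 3}))
    std::printf("%lld ", static_cast<long long>(d));  // prints: 5 3
}
```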
......@@ -22,8 +22,7 @@ template <typename T, typename Context>
void UnsqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out,
DenseTensor* xshape) {
DenseTensor* out) {
auto x_dims = x.dims();
auto out_dims = out->dims();
if (axes.FromTensor()) {
......@@ -39,4 +38,13 @@ void UnsqueezeKernel(const Context& dev_ctx,
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
out->Resize(out_dims); // copy will reset the dims.
}
template <typename T, typename Context>
void UnsqueezeWithXShapeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out,
DenseTensor* xshape) {
UnsqueezeKernel<T, Context>(dev_ctx, x, axes, out);
}
} // namespace phi
......@@ -23,6 +23,13 @@ template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
DenseTensor* out,
DenseTensor* xshape);
DenseTensor* out);
template <typename T, typename Context>
void SqueezeWithXShapeKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
DenseTensor* out,
DenseTensor* xshape);
} // namespace phi
......@@ -25,8 +25,14 @@ template <typename T, typename Context>
void UnsqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out,
DenseTensor* xshape);
DenseTensor* out);
template <typename T, typename Context>
void UnsqueezeWithXShapeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out,
DenseTensor* xshape);
template <typename T, typename Context>
void Unsqueeze(const Context& dev_ctx,
......@@ -35,8 +41,8 @@ void Unsqueeze(const Context& dev_ctx,
DenseTensor* out,
DenseTensor* xshape) {
MetaTensor meta_out(out);
UnsqueezeInferMeta(x, axes, &meta_out, nullptr, MetaConfig());
UnsqueezeKernel<T, Context>(dev_ctx, x, axes, out, nullptr);
UnsqueezeInferMeta(x, axes, &meta_out);
UnsqueezeKernel<T, Context>(dev_ctx, x, axes, out);
}
} // namespace phi
......@@ -17,8 +17,14 @@ limitations under the License. */
namespace phi {
KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache", "XShape"});
if (ctx.OutputSize("XShape") > 0 && ctx.OutputSize("InnerCache") > 0) {
return KernelSignature("einsum_raw",
{"Operands"},
{"equation"},
{"Out", "InnerCache", "XShape"});
} else {
return KernelSignature("einsum", {"Operands"}, {"equation"}, {"Out"});
}
}
KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
......
......@@ -18,7 +18,12 @@
namespace phi {
KernelSignature SqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("squeeze", {"X"}, {"axes"}, {"Out", "XShape"});
if (ctx.HasOutput("XShape")) {
return KernelSignature(
"squeeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"});
} else {
return KernelSignature("squeeze", {"X"}, {"axes"}, {"Out"});
}
}
KernelSignature SqueezeGradOpArgumentMapping(
......
......@@ -18,17 +18,33 @@
namespace phi {
KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.InputSize("AxesTensorList") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensorList";
return KernelSignature(
"unsqueeze", {"X"}, {"AxesTensorList"}, {"Out", "XShape"});
} else if (ctx.InputSize("AxesTensor") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensor";
return KernelSignature(
"unsqueeze", {"X"}, {"AxesTensor"}, {"Out", "XShape"});
if (ctx.HasOutput("XShape")) {
if (ctx.InputSize("AxesTensorList") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensorList";
return KernelSignature("unsqueeze_with_xshape",
{"X"},
{"AxesTensorList"},
{"Out", "XShape"});
} else if (ctx.InputSize("AxesTensor") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensor";
return KernelSignature(
"unsqueeze_with_xshape", {"X"}, {"AxesTensor"}, {"Out", "XShape"});
} else {
VLOG(2) << "unsqueeze2 in axes";
return KernelSignature(
"unsqueeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"});
}
} else {
VLOG(2) << "unsqueeze2 in axes";
return KernelSignature("unsqueeze", {"X"}, {"axes"}, {"Out", "XShape"});
if (ctx.InputSize("AxesTensorList") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensorList";
return KernelSignature("unsqueeze", {"X"}, {"AxesTensorList"}, {"Out"});
} else if (ctx.InputSize("AxesTensor") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensor";
return KernelSignature("unsqueeze", {"X"}, {"AxesTensor"}, {"Out"});
} else {
VLOG(2) << "unsqueeze2 in axes";
return KernelSignature("unsqueeze", {"X"}, {"axes"}, {"Out"});
}
}
}
......
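All three argument mappings above follow one rule: if the legacy op's `XShape` output is present (plus `InnerCache` for einsum), route to the `_raw`/`_with_xshape` kernel, otherwise to the slim one; unsqueeze additionally prefers `AxesTensorList` over `AxesTensor` over the `axes` attribute. A mock of that selection with a hypothetical trimmed-down `Ctx` (the real `ArgumentMappingContext` lives in phi's compat layer):

```cpp
#include <cassert>
#include <string>

// Hypothetical, trimmed-down context for illustration only.
struct Ctx {
  bool has_xshape;
  int axes_tensor_list_size;
  int axes_tensor_size;
};

std::string PickUnsqueezeKernel(const Ctx& ctx) {
  const std::string base =
      ctx.has_xshape ? "unsqueeze_with_xshape" : "unsqueeze";
  // Axes source priority mirrors the mapping above.
  if (ctx.axes_tensor_list_size > 0) return base + "/AxesTensorList";
  if (ctx.axes_tensor_size > 0) return base + "/AxesTensor";
  return base + "/axes";
}

int main() {
  assert(PickUnsqueezeKernel({true, 1, 0}) ==
         "unsqueeze_with_xshape/AxesTensorList");
  assert(PickUnsqueezeKernel({false, 0, 0}) == "unsqueeze/axes");
}
```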