未验证 提交 98e96853 编写于 作者: Y YuanRisheng 提交者: GitHub

[PHI]Seperate xshape kernel from normal kernel (#44315)

* seperate xshape kernel from normal kernel

* fix bugs in infermeta

* fix compile bugs

* fix compile bugs
上级 15dd94ab
...@@ -106,7 +106,7 @@ namespace ops = paddle::operators; ...@@ -106,7 +106,7 @@ namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(einsum, DECLARE_INFER_SHAPE_FUNCTOR(einsum,
EinsumInferShapeFunctor, EinsumInferShapeFunctor,
PD_INFER_META(phi::EinsumInferMeta)); PD_INFER_META(phi::EinsumRawInferMeta));
REGISTER_OPERATOR(einsum, REGISTER_OPERATOR(einsum,
ops::EinsumOp, ops::EinsumOp,
......
...@@ -347,7 +347,7 @@ namespace ops = paddle::operators; ...@@ -347,7 +347,7 @@ namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(squeeze2, DECLARE_INFER_SHAPE_FUNCTOR(squeeze2,
SqueezeInferShapeFunctor, SqueezeInferShapeFunctor,
PD_INFER_META(phi::SqueezeInferMeta)); PD_INFER_META(phi::SqueezeWithXShapeInferMeta));
REGISTER_OPERATOR(squeeze, REGISTER_OPERATOR(squeeze,
ops::SqueezeOp, ops::SqueezeOp,
......
...@@ -347,7 +347,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnsqueezeGradOpNoNeedBufferVarInferer, "X"); ...@@ -347,7 +347,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnsqueezeGradOpNoNeedBufferVarInferer, "X");
DECLARE_INFER_SHAPE_FUNCTOR(unsqueeze2, DECLARE_INFER_SHAPE_FUNCTOR(unsqueeze2,
Unsqueeze2InferShapeFunctor, Unsqueeze2InferShapeFunctor,
PD_INFER_META(phi::UnsqueezeInferMeta)); PD_INFER_META(phi::UnsqueezeWithXShapeInferMeta));
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(unsqueeze, REGISTER_OPERATOR(unsqueeze,
......
...@@ -325,8 +325,8 @@ add_custom_command( ...@@ -325,8 +325,8 @@ add_custom_command(
${dygraph_api_header_file} ${dygraph_api_header_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dygraph_api_source_file_tmp} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dygraph_api_source_file_tmp}
${dygraph_api_source_file} ${dygraph_api_source_file}
DEPENDS ${api_yaml_file} ${sparse_api_yaml_file} ${im_api_gen_file} DEPENDS ${api_yaml_file} ${legacy_api_yaml_file} ${sparse_api_yaml_file}
${api_gen_base} ${api_gen_file} ${im_api_gen_file} ${api_gen_base} ${api_gen_file}
VERBATIM) VERBATIM)
# generate wrapped infermeta # generate wrapped infermeta
......
...@@ -582,10 +582,10 @@ ...@@ -582,10 +582,10 @@
args : (Tensor[] x, str equation) args : (Tensor[] x, str equation)
output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()} output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()}
infer_meta : infer_meta :
func : EinsumInferMeta func : EinsumRawInferMeta
param : [x, equation] param : [x, equation]
kernel : kernel :
func : einsum func : einsum_raw
backward : einsum_grad backward : einsum_grad
- api : elementwise_pow - api : elementwise_pow
...@@ -2047,9 +2047,9 @@ ...@@ -2047,9 +2047,9 @@
args : (Tensor x, int[] axes) args : (Tensor x, int[] axes)
output : Tensor(out), Tensor(xshape) output : Tensor(out), Tensor(xshape)
infer_meta : infer_meta :
func : SqueezeInferMeta func : SqueezeWithXShapeInferMeta
kernel : kernel :
func : squeeze func : squeeze_with_xshape
view: (x -> out) view: (x -> out)
intermediate : xshape intermediate : xshape
backward : squeeze_grad backward : squeeze_grad
...@@ -2290,9 +2290,9 @@ ...@@ -2290,9 +2290,9 @@
args : (Tensor x, IntArray axis) args : (Tensor x, IntArray axis)
output : Tensor(out), Tensor(xshape) output : Tensor(out), Tensor(xshape)
infer_meta : infer_meta :
func : UnsqueezeInferMeta func : UnsqueezeWithXShapeInferMeta
kernel : kernel :
func : unsqueeze func : unsqueeze_with_xshape
view: (x -> out) view: (x -> out)
intermediate : xshape intermediate : xshape
backward : unsqueeze_grad backward : unsqueeze_grad
......
...@@ -570,9 +570,7 @@ void EigvalsInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config) { ...@@ -570,9 +570,7 @@ void EigvalsInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config) {
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs, void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation, const std::string& equation,
MetaTensor* out, MetaTensor* out) {
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape) {
// collect the following informations to prepare einsum. // collect the following informations to prepare einsum.
LabelMap labelshape(0); LabelMap labelshape(0);
LabelMap labeltype(LabelType::Reduction); LabelMap labeltype(LabelType::Reduction);
...@@ -609,6 +607,14 @@ void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs, ...@@ -609,6 +607,14 @@ void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape); VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape);
out->set_dims(make_ddim(output_dims)); out->set_dims(make_ddim(output_dims));
out->set_dtype(inputs[0]->dtype()); out->set_dtype(inputs[0]->dtype());
}
void EinsumRawInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape) {
EinsumInferMeta(inputs, equation, out);
for (size_t i = 0; i < xshape.size(); ++i) { for (size_t i = 0; i < xshape.size(); ++i) {
if (xshape[i] != nullptr) { if (xshape[i] != nullptr) {
xshape[i]->set_dims(inputs[i]->dims()); xshape[i]->set_dims(inputs[i]->dims());
...@@ -2448,8 +2454,7 @@ void SplitInferMeta(const MetaTensor& x, ...@@ -2448,8 +2454,7 @@ void SplitInferMeta(const MetaTensor& x,
void SqueezeInferMeta(const MetaTensor& x, void SqueezeInferMeta(const MetaTensor& x,
const std::vector<int>& axes, const std::vector<int>& axes,
MetaTensor* out, MetaTensor* out) {
MetaTensor* xshape) {
const auto& x_dims = x.dims(); const auto& x_dims = x.dims();
// Check input tensor dims (<6) Eigen limit. // Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE_LE(x_dims.size(), PADDLE_ENFORCE_LE(x_dims.size(),
...@@ -2469,15 +2474,25 @@ void SqueezeInferMeta(const MetaTensor& x, ...@@ -2469,15 +2474,25 @@ void SqueezeInferMeta(const MetaTensor& x,
out->share_lod(x); out->share_lod(x);
} }
out->set_dtype(x.dtype());
}
void SqueezeWithXShapeInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
MetaTensor* out,
MetaTensor* xshape) {
SqueezeInferMeta(x, axes, out);
const auto& x_dims = x.dims();
std::vector<int64_t> xshape_dims(x_dims.size() + 1); std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0; xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) { for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i]; xshape_dims[i + 1] = x_dims[i];
} }
xshape->set_dims(phi::make_ddim(xshape_dims)); if (xshape) {
xshape->share_lod(x); xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->set_dtype(x.dtype()); xshape->share_lod(x);
out->set_dtype(x.dtype()); xshape->set_dtype(x.dtype());
}
} }
void StridedSliceRawInferMeta(const MetaTensor& x, void StridedSliceRawInferMeta(const MetaTensor& x,
...@@ -3310,7 +3325,6 @@ void UniqueRawInferMeta(const MetaTensor& x, ...@@ -3310,7 +3325,6 @@ void UniqueRawInferMeta(const MetaTensor& x,
void UnsqueezeInferMeta(const MetaTensor& x, void UnsqueezeInferMeta(const MetaTensor& x,
const IntArray& axes, const IntArray& axes,
MetaTensor* out, MetaTensor* out,
MetaTensor* xshape,
MetaConfig config) { MetaConfig config) {
const auto& x_dims = x.dims(); const auto& x_dims = x.dims();
// Validity Check: input tensor dims (<6). // Validity Check: input tensor dims (<6).
...@@ -3339,14 +3353,22 @@ void UnsqueezeInferMeta(const MetaTensor& x, ...@@ -3339,14 +3353,22 @@ void UnsqueezeInferMeta(const MetaTensor& x,
} }
out->set_dtype(x.dtype()); out->set_dtype(x.dtype());
} }
if (xshape) { }
// set xshape dims.
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
void UnsqueezeWithXShapeInferMeta(const MetaTensor& x,
const IntArray& axes,
MetaTensor* out,
MetaTensor* xshape,
MetaConfig config) {
const auto& x_dims = x.dims();
UnsqueezeInferMeta(x, axes, out, config);
// set xshape dims.
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
if (xshape) {
xshape->set_dims(phi::make_ddim(xshape_dims)); xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x); xshape->share_lod(x);
xshape->set_dtype(x.dtype()); xshape->set_dtype(x.dtype());
......
...@@ -97,9 +97,13 @@ void EigvalsInferMeta(const MetaTensor& x, ...@@ -97,9 +97,13 @@ void EigvalsInferMeta(const MetaTensor& x,
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs, void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation, const std::string& equation,
MetaTensor* out, MetaTensor* out);
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape); void EinsumRawInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape);
void ExpandInferMeta(const MetaTensor& x, void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape, const IntArray& shape,
...@@ -341,8 +345,12 @@ void SplitInferMeta(const MetaTensor& x_meta, ...@@ -341,8 +345,12 @@ void SplitInferMeta(const MetaTensor& x_meta,
void SqueezeInferMeta(const MetaTensor& x, void SqueezeInferMeta(const MetaTensor& x,
const std::vector<int>& axes, const std::vector<int>& axes,
MetaTensor* out, MetaTensor* out);
MetaTensor* xshape);
void SqueezeWithXShapeInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
MetaTensor* out,
MetaTensor* xshape);
void StridedSliceRawInferMeta(const MetaTensor& x, void StridedSliceRawInferMeta(const MetaTensor& x,
const std::vector<int>& axes, const std::vector<int>& axes,
...@@ -470,9 +478,14 @@ void UniqueRawInferMeta(const MetaTensor& x, ...@@ -470,9 +478,14 @@ void UniqueRawInferMeta(const MetaTensor& x,
void UnsqueezeInferMeta(const MetaTensor& x, void UnsqueezeInferMeta(const MetaTensor& x,
const IntArray& axes, const IntArray& axes,
MetaTensor* out, MetaTensor* out,
MetaTensor* xshape,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void UnsqueezeWithXShapeInferMeta(const MetaTensor& x,
const IntArray& axes,
MetaTensor* out,
MetaTensor* xshape,
MetaConfig config = MetaConfig());
void UnStackInferMeta(const MetaTensor& x, void UnStackInferMeta(const MetaTensor& x,
int axis, int axis,
int num, int num,
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_impl.h" #include "paddle/phi/kernels/impl/einsum_impl.h"
PD_REGISTER_KERNEL(einsum, PD_REGISTER_KERNEL(einsum_raw,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::EinsumKernelRaw, phi::EinsumKernelRaw,
...@@ -26,3 +26,12 @@ PD_REGISTER_KERNEL(einsum, ...@@ -26,3 +26,12 @@ PD_REGISTER_KERNEL(einsum,
double, double,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(einsum,
CPU,
ALL_LAYOUT,
phi::EinsumKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
...@@ -32,3 +32,18 @@ PD_REGISTER_KERNEL(squeeze, ...@@ -32,3 +32,18 @@ PD_REGISTER_KERNEL(squeeze,
int64_t, int64_t,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(squeeze_with_xshape,
CPU,
ALL_LAYOUT,
phi::SqueezeWithXShapeKernel,
float,
double,
phi::dtype::bfloat16,
bool,
int,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
...@@ -33,3 +33,19 @@ PD_REGISTER_KERNEL(unsqueeze, ...@@ -33,3 +33,19 @@ PD_REGISTER_KERNEL(unsqueeze,
int64_t, int64_t,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(unsqueeze_with_xshape,
CPU,
ALL_LAYOUT,
phi::UnsqueezeWithXShapeKernel,
float,
double,
phi::dtype::bfloat16,
bool,
int,
int16_t,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_impl.h" #include "paddle/phi/kernels/impl/einsum_impl.h"
PD_REGISTER_KERNEL(einsum, PD_REGISTER_KERNEL(einsum_raw,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::EinsumKernelRaw, phi::EinsumKernelRaw,
...@@ -28,3 +28,14 @@ PD_REGISTER_KERNEL(einsum, ...@@ -28,3 +28,14 @@ PD_REGISTER_KERNEL(einsum,
phi::dtype::bfloat16, phi::dtype::bfloat16,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(einsum,
GPU,
ALL_LAYOUT,
phi::EinsumKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
...@@ -33,3 +33,19 @@ PD_REGISTER_KERNEL(squeeze, ...@@ -33,3 +33,19 @@ PD_REGISTER_KERNEL(squeeze,
int64_t, int64_t,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(squeeze_with_xshape,
GPU,
ALL_LAYOUT,
phi::SqueezeWithXShapeKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16,
bool,
int,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
...@@ -34,3 +34,20 @@ PD_REGISTER_KERNEL(unsqueeze, ...@@ -34,3 +34,20 @@ PD_REGISTER_KERNEL(unsqueeze,
int64_t, int64_t,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(unsqueeze_with_xshape,
GPU,
ALL_LAYOUT,
phi::UnsqueezeWithXShapeKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
bool,
int,
int16_t,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once
#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/expand_as_kernel.h" #include "paddle/phi/kernels/expand_as_kernel.h"
#include "paddle/phi/kernels/funcs/matrix_solve.h" #include "paddle/phi/kernels/funcs/matrix_solve.h"
...@@ -77,7 +79,7 @@ static std::vector<int64_t> get_broadcast_batch_portion( ...@@ -77,7 +79,7 @@ static std::vector<int64_t> get_broadcast_batch_portion(
static inline std::vector<int> convert_to_int_vec(std::vector<int64_t> a) { static inline std::vector<int> convert_to_int_vec(std::vector<int64_t> a) {
std::vector<int> ret; std::vector<int> ret;
for (size_t i = 0; i < a.size(); i++) { for (size_t i = 0; i < a.size(); i++) {
ret.emplace_back(int(a[i])); ret.emplace_back(static_cast<int>(a[i]));
} }
return ret; return ret;
...@@ -167,7 +169,7 @@ static void linalg_solve(const Context& dev_ctx, ...@@ -167,7 +169,7 @@ static void linalg_solve(const Context& dev_ctx,
out_tmp.Resize(out->dims()); out_tmp.Resize(out->dims());
out_tmp = *out; out_tmp = *out;
phi::SqueezeKernel<T, Context>(dev_ctx, out_tmp, {-1}, out, nullptr); phi::SqueezeKernel<T, Context>(dev_ctx, out_tmp, {-1}, out);
} else { } else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_dim[x_dim_size - 1], x_dim[x_dim_size - 1],
......
...@@ -22,8 +22,7 @@ template <typename T, typename Context> ...@@ -22,8 +22,7 @@ template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx, void SqueezeKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int>& axes, const std::vector<int>& axes,
DenseTensor* out, DenseTensor* out) {
DenseTensor* xshape) {
auto x_dims = x.dims(); auto x_dims = x.dims();
auto out_dims = funcs::GetOutputSqueezeShape(axes, x_dims, true); auto out_dims = funcs::GetOutputSqueezeShape(axes, x_dims, true);
...@@ -31,4 +30,14 @@ void SqueezeKernel(const Context& dev_ctx, ...@@ -31,4 +30,14 @@ void SqueezeKernel(const Context& dev_ctx,
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
out->Resize(out_dims); out->Resize(out_dims);
} }
template <typename T, typename Context>
void SqueezeWithXShapeKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
DenseTensor* out,
DenseTensor* xshape) {
SqueezeKernel<T, Context>(dev_ctx, x, axes, out);
}
} // namespace phi } // namespace phi
...@@ -22,8 +22,7 @@ template <typename T, typename Context> ...@@ -22,8 +22,7 @@ template <typename T, typename Context>
void UnsqueezeKernel(const Context& dev_ctx, void UnsqueezeKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const IntArray& axes, const IntArray& axes,
DenseTensor* out, DenseTensor* out) {
DenseTensor* xshape) {
auto x_dims = x.dims(); auto x_dims = x.dims();
auto out_dims = out->dims(); auto out_dims = out->dims();
if (axes.FromTensor()) { if (axes.FromTensor()) {
...@@ -39,4 +38,13 @@ void UnsqueezeKernel(const Context& dev_ctx, ...@@ -39,4 +38,13 @@ void UnsqueezeKernel(const Context& dev_ctx,
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
out->Resize(out_dims); // copy will reset the dims. out->Resize(out_dims); // copy will reset the dims.
} }
template <typename T, typename Context>
void UnsqueezeWithXShapeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out,
DenseTensor* xshape) {
UnsqueezeKernel<T, Context>(dev_ctx, x, axes, out);
}
} // namespace phi } // namespace phi
...@@ -23,6 +23,13 @@ template <typename T, typename Context> ...@@ -23,6 +23,13 @@ template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx, void SqueezeKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int>& axes, const std::vector<int>& axes,
DenseTensor* out, DenseTensor* out);
DenseTensor* xshape);
template <typename T, typename Context>
void SqueezeWithXShapeKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
DenseTensor* out,
DenseTensor* xshape);
} // namespace phi } // namespace phi
...@@ -25,8 +25,14 @@ template <typename T, typename Context> ...@@ -25,8 +25,14 @@ template <typename T, typename Context>
void UnsqueezeKernel(const Context& dev_ctx, void UnsqueezeKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const IntArray& axes, const IntArray& axes,
DenseTensor* out, DenseTensor* out);
DenseTensor* xshape);
template <typename T, typename Context>
void UnsqueezeWithXShapeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out,
DenseTensor* xshape);
template <typename T, typename Context> template <typename T, typename Context>
void Unsqueeze(const Context& dev_ctx, void Unsqueeze(const Context& dev_ctx,
...@@ -35,8 +41,8 @@ void Unsqueeze(const Context& dev_ctx, ...@@ -35,8 +41,8 @@ void Unsqueeze(const Context& dev_ctx,
DenseTensor* out, DenseTensor* out,
DenseTensor* xshape) { DenseTensor* xshape) {
MetaTensor meta_out(out); MetaTensor meta_out(out);
UnsqueezeInferMeta(x, axes, &meta_out, nullptr, MetaConfig()); UnsqueezeInferMeta(x, axes, &meta_out);
UnsqueezeKernel<T, Context>(dev_ctx, x, axes, out, nullptr); UnsqueezeKernel<T, Context>(dev_ctx, x, axes, out);
} }
} // namespace phi } // namespace phi
...@@ -17,8 +17,14 @@ limitations under the License. */ ...@@ -17,8 +17,14 @@ limitations under the License. */
namespace phi { namespace phi {
KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( if (ctx.OutputSize("XShape") > 0 && ctx.OutputSize("InnerCache") > 0) {
"einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache", "XShape"}); return KernelSignature("einsum_raw",
{"Operands"},
{"equation"},
{"Out", "InnerCache", "XShape"});
} else {
return KernelSignature("einsum", {"Operands"}, {"equation"}, {"Out"});
}
} }
KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
......
...@@ -18,7 +18,12 @@ ...@@ -18,7 +18,12 @@
namespace phi { namespace phi {
KernelSignature SqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature SqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("squeeze", {"X"}, {"axes"}, {"Out", "XShape"}); if (ctx.HasOutput("XShape")) {
return KernelSignature(
"squeeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"});
} else {
return KernelSignature("squeeze", {"X"}, {"axes"}, {"Out"});
}
} }
KernelSignature SqueezeGradOpArgumentMapping( KernelSignature SqueezeGradOpArgumentMapping(
......
...@@ -18,17 +18,33 @@ ...@@ -18,17 +18,33 @@
namespace phi { namespace phi {
KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.InputSize("AxesTensorList") > 0) { if (ctx.HasOutput("XShape")) {
VLOG(2) << "unsqueeze2 in AxesTensorList"; if (ctx.InputSize("AxesTensorList") > 0) {
return KernelSignature( VLOG(2) << "unsqueeze2 in AxesTensorList";
"unsqueeze", {"X"}, {"AxesTensorList"}, {"Out", "XShape"}); return KernelSignature("unsqueeze_with_xshape",
} else if (ctx.InputSize("AxesTensor") > 0) { {"X"},
VLOG(2) << "unsqueeze2 in AxesTensor"; {"AxesTensorList"},
return KernelSignature( {"Out", "XShape"});
"unsqueeze", {"X"}, {"AxesTensor"}, {"Out", "XShape"}); } else if (ctx.InputSize("AxesTensor") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensor";
return KernelSignature(
"unsqueeze_with_xshape", {"X"}, {"AxesTensor"}, {"Out", "XShape"});
} else {
VLOG(2) << "unsqueeze2 in axes";
return KernelSignature(
"unsqueeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"});
}
} else { } else {
VLOG(2) << "unsqueeze2 in axes"; if (ctx.InputSize("AxesTensorList") > 0) {
return KernelSignature("unsqueeze", {"X"}, {"axes"}, {"Out", "XShape"}); VLOG(2) << "unsqueeze2 in AxesTensorList";
return KernelSignature("unsqueeze", {"X"}, {"AxesTensorList"}, {"Out"});
} else if (ctx.InputSize("AxesTensor") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensor";
return KernelSignature("unsqueeze", {"X"}, {"AxesTensor"}, {"Out"});
} else {
VLOG(2) << "unsqueeze2 in axes";
return KernelSignature("unsqueeze", {"X"}, {"axes"}, {"Out"});
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册