Unverified commit 13f1641d authored by YuanRisheng, committed by GitHub

move elementwise_mul selected rows input (#41042)

Parent 04325d2c
@@ -23,7 +23,7 @@
 #include "paddle/phi/core/kernel_registry.h"
 USE_OP_ITSELF(scale);
-USE_OP(elementwise_mul);
+USE_OP_ITSELF(elementwise_mul);
 USE_OP_ITSELF(elementwise_add);
 USE_OP_ITSELF(elementwise_add_grad);
...
@@ -104,4 +104,4 @@ TEST(elementwise_op, plugin) {
 }  // namespace paddle
 USE_OP_ITSELF(elementwise_add);
-USE_OP(elementwise_mul);
+USE_OP_ITSELF(elementwise_mul);
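A note on this recurring test change: `USE_OP(elementwise_mul)` becomes `USE_OP_ITSELF(elementwise_mul)` because, once the kernels move to phi, the old fluid kernel registrations that `USE_OP` additionally references no longer exist; the tests only need the operator definition itself. (This reading follows the macro names; the actual definitions live in paddle/fluid/framework/op_registry.h.)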
@@ -20,35 +20,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-template <typename T>
-struct SameDimsElemwiseMul<
-    platform::CPUDeviceContext, T,
-    typename std::enable_if<std::is_floating_point<T>::value>::type> {
-  void operator()(const framework::ExecutionContext &ctx,
-                  const framework::Tensor *x, const framework::Tensor *y,
-                  framework::Tensor *z) {
-    auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(ctx);
-    blas.VMUL(x->numel(), x->data<T>(), y->data<T>(), z->data<T>());
-  }
-};
-template <typename T>
-struct SameDimsElemwiseMul<
-    platform::CPUDeviceContext, T,
-    typename std::enable_if<!std::is_floating_point<T>::value>::type> {
-  void operator()(const framework::ExecutionContext &ctx,
-                  const framework::Tensor *x, const framework::Tensor *y,
-                  framework::Tensor *z) {
-    auto eigen_x = framework::EigenVector<T>::Flatten(*x);
-    auto eigen_y = framework::EigenVector<T>::Flatten(*y);
-    auto eigen_z = framework::EigenVector<T>::Flatten(*z);
-    auto &place = *ctx.template device_context<platform::CPUDeviceContext>()
-                       .eigen_device();
-    eigen_z.device(place) = eigen_x * eigen_y;
-  }
-};
 class ElementwiseMulOpMaker : public ElementwiseOpMaker {
  protected:
   std::string GetName() const override { return "Mul"; }
@@ -160,20 +131,6 @@ REGISTER_OPERATOR(
 REGISTER_OPERATOR(elementwise_mul_triple_grad, ops::ElementwiseOpTripleGrad);
-REGISTER_OP_CPU_KERNEL(
-    elementwise_mul,
-    ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, bool>,
-    ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::bfloat16>,
-    ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex<float>>,
-    ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex<double>>);
 REGISTER_OP_VERSION(elementwise_mul)
     .AddCheckpoint(
         R"ROC(Register elementwise_mul for adding the attribute of Scale_y)ROC",
...
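The two `SameDimsElemwiseMul` specializations deleted above dispatched on the element type at compile time: floating-point types went through the vectorized BLAS `VMUL` routine, all other types through an Eigen expression. A standalone sketch of that `enable_if` dispatch pattern, with plain loops standing in for the BLAS/Eigen calls:

#include <cstddef>
#include <iostream>
#include <type_traits>
#include <vector>

// Primary template: exactly one partial specialization below is viable
// for a given T, selected through the SFINAE Enable parameter.
template <typename T, typename Enable = void>
struct SameDimsMul;

// Floating-point types: stands in for the blas.VMUL fast path.
template <typename T>
struct SameDimsMul<
    T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
  void operator()(const std::vector<T>& x, const std::vector<T>& y,
                  std::vector<T>* z) {
    for (std::size_t i = 0; i < x.size(); ++i) (*z)[i] = x[i] * y[i];
  }
};

// All other types (int, int64_t, ...): stands in for the Eigen path.
template <typename T>
struct SameDimsMul<
    T, typename std::enable_if<!std::is_floating_point<T>::value>::type> {
  void operator()(const std::vector<T>& x, const std::vector<T>& y,
                  std::vector<T>* z) {
    for (std::size_t i = 0; i < x.size(); ++i) (*z)[i] = x[i] * y[i];
  }
};

int main() {
  std::vector<float> a{1.0f, 2.0f}, b{3.0f, 4.0f}, c(2);
  SameDimsMul<float>()(a, b, &c);            // selects the "VMUL" specialization
  std::cout << c[0] << " " << c[1] << "\n";  // prints: 3 8
  return 0;
}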
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
namespace paddle {
namespace operators {
template <typename T>
class ElementwiseMulKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto x_var = ctx.InputVar("X");
PADDLE_ENFORCE_EQ(x_var != nullptr, true,
platform::errors::InvalidArgument(
"Cannot get input Variable X, Variable name = %s.",
ctx.InputName("X")));
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
if (x_var->IsType<phi::SelectedRows>()) {
framework::Tensor x_for_selectedrows;
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
int axis =
PackTensorsIntoVector<T>(ctx, &ins, &outs, &x_for_selectedrows);
paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary,
T, T>(
cuda_ctx, ins, &outs, axis, MulFunctor<T>());
} else if (x_var->IsType<framework::LoDTensor>()) {
auto* x_lod = ctx.Input<framework::LoDTensor>("X");
auto* y_lod = ctx.Input<framework::LoDTensor>("Y");
auto* z_lod = ctx.Output<framework::LoDTensor>("Out");
z_lod->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
auto pt_x = paddle::experimental::MakePhiDenseTensor(*x_lod);
auto pt_y = paddle::experimental::MakePhiDenseTensor(*y_lod);
auto pt_z = paddle::experimental::MakePhiDenseTensor(*z_lod);
phi::MultiplyRawKernel<T>(static_cast<const phi::GPUContext&>(cuda_ctx),
*pt_x.get(), *pt_y.get(), axis, pt_z.get());
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"X's type[%s] is not supported by elementwise_op. X's type should be "
"LoDTensor or SelectedRows.",
framework::ToTypeName(x_var->Type())));
}
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
elementwise_mul, ops::ElementwiseMulKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, bool>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::bfloat16>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<float>>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<double>>);
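The CUDA kernel above branches on the runtime type of variable X: a `phi::SelectedRows` input is packed into dense tensor lists and launched through the element-wise CUDA path, while a `framework::LoDTensor` input is wrapped as phi DenseTensors and forwarded to `phi::MultiplyRawKernel`. A toy model of that variable-type dispatch, using simplified stand-ins rather than the real framework classes:

#include <iostream>
#include <variant>
#include <vector>

// Simplified stand-ins for framework::LoDTensor and phi::SelectedRows.
struct LoDTensor { std::vector<float> data; };
struct SelectedRows { std::vector<long> rows; LoDTensor value; };
using Variable = std::variant<LoDTensor, SelectedRows>;

void MultiplyCompute(const Variable& x_var) {
  if (auto* sr = std::get_if<SelectedRows>(&x_var)) {
    // SelectedRows path: operate on the packed dense value tensor.
    std::cout << "SelectedRows path, " << sr->value.data.size() << " values\n";
  } else if (auto* lod = std::get_if<LoDTensor>(&x_var)) {
    // Dense path: forward to the phi multiply kernel.
    std::cout << "LoDTensor path, " << lod->data.size() << " values\n";
  }
}

int main() {
  MultiplyCompute(LoDTensor{{1.f, 2.f, 3.f}});
  MultiplyCompute(SelectedRows{{0, 2}, {{1.f, 2.f}}});
  return 0;
}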
@@ -58,85 +58,5 @@ class ElementwiseMulOp : public ElementwiseOp {
 }
 };
-template <typename DeviceContext, typename T>
-void default_elementwise_mul(const framework::ExecutionContext& ctx,
-                             const framework::Tensor* x,
-                             const framework::Tensor* y, framework::Tensor* z) {
-  int axis = ctx.Attr<int>("axis");
-  auto x_dims = x->dims();
-  auto y_dims = y->dims();
-  if (x_dims.size() >= y_dims.size()) {
-    ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
-                                                          MulFunctor<T>(), z);
-  } else {
-    ElementwiseComputeEx<InverseMulFunctor<T>, DeviceContext, T>(
-        ctx, x, y, axis, InverseMulFunctor<T>(), z);
-  }
-}
-template <typename DeviceContext, typename T, class Enable = void>
-struct SameDimsElemwiseMul {
-  void operator()(const framework::ExecutionContext& ctx,
-                  const framework::Tensor* x, const framework::Tensor* y,
-                  framework::Tensor* z);
-};
-template <typename DeviceContext, typename T>
-class ElementwiseMulKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto x_var = ctx.InputVar("X");
-    PADDLE_ENFORCE_EQ(x_var != nullptr, true,
-                      platform::errors::InvalidArgument(
-                          "Cannot get input Variable X, Variable name = %s.",
-                          ctx.InputName("X")));
-    auto* y = ctx.Input<framework::LoDTensor>("Y");
-    framework::Tensor x, *z;
-    if (x_var->IsType<phi::SelectedRows>()) {
-      PADDLE_ENFORCE_EQ(y->dims().size() == 1 && y->dims()[0] == 1, true,
-                        platform::errors::InvalidArgument(
-                            "For elementwise_op, if X is Sparse, Y must be "
-                            "scalar. But received the size of Y = %s.",
-                            y->dims().size()));
-      auto& x_sele = x_var->Get<phi::SelectedRows>();
-      auto out_sele = ctx.Output<phi::SelectedRows>("Out");
-      x = x_sele.value();
-      out_sele->set_rows(x_sele.rows());
-      out_sele->set_height(x_sele.height());
-      out_sele->mutable_value()->Resize(x_sele.value().dims());
-      out_sele->mutable_value()->mutable_data(ctx.GetPlace(), x.type());
-      z = ctx.Output<phi::SelectedRows>("Out")->mutable_value();
-      z->mutable_data<T>(ctx.GetPlace());
-      auto dims_equal = x.dims() == y->dims();
-      if (dims_equal) {
-        SameDimsElemwiseMul<DeviceContext, T> same_dims_mul;
-        same_dims_mul(ctx, &x, y, z);
-      } else {
-        default_elementwise_mul<DeviceContext, T>(ctx, &x, y, z);
-      }
-    } else if (x_var->IsType<framework::LoDTensor>()) {
-      auto* x_lod = ctx.Input<framework::LoDTensor>("X");
-      auto* z_lod = ctx.Output<framework::LoDTensor>("Out");
-      z_lod->mutable_data<T>(ctx.GetPlace());
-      auto& dev_ctx = ctx.device_context<DeviceContext>();
-      int axis = ctx.Attr<int>("axis");
-      auto pt_x = paddle::experimental::MakePhiDenseTensor(*x_lod);
-      auto pt_y = paddle::experimental::MakePhiDenseTensor(*y);
-      auto pt_z = paddle::experimental::MakePhiDenseTensor(*z_lod);
-      phi::MultiplyRawKernel<T>(
-          static_cast<const typename framework::ConvertToPhiContext<
-              DeviceContext>::TYPE&>(dev_ctx),
-          *pt_x.get(), *pt_y.get(), axis, pt_z.get());
-    } else {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "X's type[%s] is not supported by elementwise_op. X's type should "
-          "be LoDTensor or SelectedRows.",
-          framework::ToTypeName(x_var->Type())));
-    }
-  }
-};
 }  // namespace operators
 }  // namespace paddle
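Note how the deleted `default_elementwise_mul` picks the functor by rank: the broadcast machinery assumes the first operand has rank >= the second, so when `y` has the higher rank the operands arrive swapped and an inverse functor is used. For multiplication the two are equivalent, since `*` is commutative. A sketch of the pair (functor names mirror the diff, not a public API):

#include <iostream>

template <typename T>
struct MulFunctor {
  // Used when rank(x) >= rank(y): operands in natural order.
  inline T operator()(T a, T b) const { return a * b; }
};

template <typename T>
struct InverseMulFunctor {
  // Used when rank(x) < rank(y): operands arrive swapped, so multiply in
  // reverse order. A no-op for *, but the same pattern matters for
  // non-commutative ops such as div and sub.
  inline T operator()(T a, T b) const { return b * a; }
};

int main() {
  std::cout << MulFunctor<int>()(3, 4) << " "
            << InverseMulFunctor<int>()(3, 4) << "\n";  // prints: 12 12
  return 0;
}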
@@ -27,7 +27,7 @@
 USE_OP_ITSELF(elementwise_add);
 USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
-USE_OP(elementwise_mul);
+USE_OP_ITSELF(elementwise_mul);
 USE_OP_DEVICE_KERNEL(elementwise_mul, MKLDNN);
 USE_OP_ITSELF(relu);
 USE_OP_DEVICE_KERNEL(relu, MKLDNN);
...
@@ -202,6 +202,7 @@ PD_REGISTER_KERNEL(multiply,
                    int64_t,
                    bool,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    complex64,
                    complex128) {}
 PD_REGISTER_KERNEL(maximum,
...
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/selected_rows/elementwise_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/elementwise_kernel.h"
namespace phi {
namespace sr {
template <typename T, typename Context>
void MultiplyRawKernel(const Context& dev_ctx,
const SelectedRows& x,
const DenseTensor& y,
int axis,
SelectedRows* out) {
PADDLE_ENFORCE_EQ(y.dims().size() == 1 && y.dims()[0] == 1,
true,
phi::errors::InvalidArgument(
"For MultiplyKernel, if X is Sparse, Y must be "
"scalar. But reveived the size of Y = %s.",
y.dims().size()));
out->set_rows(x.rows());
out->set_height(x.height());
auto z = out->mutable_value();
z->Resize(x.value().dims());
dev_ctx.Alloc(z, x.value().dtype());
MultiplyRawKernel<T, Context>(dev_ctx, x.value(), y, axis, z);
}
template <typename T, typename Context>
void MultiplyKernel(const Context& dev_ctx,
const SelectedRows& x,
const DenseTensor& y,
SelectedRows* out) {
int axis = -1;
MultiplyRawKernel<T, Context>(dev_ctx, x, y, axis, out);
}
} // namespace sr
} // namespace phi
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(multiply_raw_sr,
CPU,
ALL_LAYOUT,
phi::sr::MultiplyRawKernel,
float,
double,
int,
int64_t,
bool,
phi::dtype::bfloat16,
complex64,
complex128) {}
PD_REGISTER_KERNEL(multiply_sr,
CPU,
ALL_LAYOUT,
phi::sr::MultiplyKernel,
float,
double,
int,
int64_t,
bool,
phi::dtype::bfloat16,
complex64,
complex128) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(multiply_raw_sr,
GPU,
ALL_LAYOUT,
phi::sr::MultiplyRawKernel,
float,
double,
int,
int64_t,
bool,
phi::dtype::bfloat16,
phi::dtype::float16,
complex64,
complex128) {}
PD_REGISTER_KERNEL(multiply_sr,
GPU,
ALL_LAYOUT,
phi::sr::MultiplyKernel,
float,
double,
int,
int64_t,
bool,
phi::dtype::bfloat16,
phi::dtype::float16,
complex64,
complex128) {}
#endif
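The new `phi::sr::MultiplyRawKernel` copies the `rows` and `height` metadata through unchanged and only multiplies the dense value tensor by the scalar `y`. A minimal model of those semantics, with plain structs standing in for the phi types:

#include <cstddef>
#include <vector>

// Simplified stand-in for phi::SelectedRows.
struct SelectedRows {
  std::vector<long> rows;    // indices of the rows actually stored
  long height = 0;           // logical row count of the full tensor
  std::vector<float> value;  // flattened dense values of the stored rows
};

// Mirrors the kernel above: metadata is forwarded, only values are scaled.
SelectedRows Multiply(const SelectedRows& x, float y_scalar) {
  SelectedRows out;
  out.rows = x.rows;                 // out->set_rows(x.rows())
  out.height = x.height;             // out->set_height(x.height())
  out.value.resize(x.value.size());  // z->Resize(...) + dev_ctx.Alloc(...)
  for (std::size_t i = 0; i < x.value.size(); ++i) {
    out.value[i] = x.value[i] * y_scalar;  // dense multiply on the value
  }
  return out;
}

int main() {
  SelectedRows x{{0, 2}, 4, {1.f, 2.f, 3.f, 4.f}};  // 2 stored rows, width 2
  SelectedRows out = Multiply(x, 10.f);
  // out.rows == {0, 2}, out.height == 4, out.value == {10, 20, 30, 40}
  return 0;
}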
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi {
namespace sr {
template <typename T, typename Context>
void MultiplyRawKernel(const Context& dev_ctx,
const SelectedRows& x,
const DenseTensor& y,
int axis,
SelectedRows* out);
template <typename T, typename Context>
void MultiplyKernel(const Context& dev_ctx,
const SelectedRows& x,
const DenseTensor& y,
SelectedRows* out);
} // namespace sr
} // namespace phi
@@ -42,8 +42,12 @@ KernelSignature ElementwiseMulOpArgumentMapping(
       return KernelSignature("multiply", {"X", "Y"}, {}, {"Out"});
     }
     return KernelSignature("multiply_raw", {"X", "Y"}, {"axis"}, {"Out"});
+  } else {
+    if (axis == -1) {
+      return KernelSignature("multiply_sr", {"X", "Y"}, {}, {"Out"});
+    }
+    return KernelSignature("multiply_raw_sr", {"X", "Y"}, {"axis"}, {"Out"});
   }
-  return KernelSignature("unregistered", {}, {}, {});
 }
 KernelSignature ElementwiseDivOpArgumentMapping(
...
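The mapping now selects one of four kernel names from two pieces of information: whether X is a dense tensor, and whether `axis` is the default `-1`. A hypothetical helper condensing those branches (the real code returns `KernelSignature` objects, as shown above):

#include <iostream>
#include <string>

// Hypothetical condensation of ElementwiseMulOpArgumentMapping's branches.
std::string PickMultiplyKernel(bool x_is_dense_tensor, int axis) {
  if (x_is_dense_tensor) {
    return axis == -1 ? "multiply" : "multiply_raw";
  }
  // SelectedRows input: route to the newly registered *_sr kernels.
  return axis == -1 ? "multiply_sr" : "multiply_raw_sr";
}

int main() {
  std::cout << PickMultiplyKernel(true, -1) << "\n";  // multiply
  std::cout << PickMultiplyKernel(false, 0) << "\n";  // multiply_raw_sr
  return 0;
}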