未验证 提交 32d9beef 编写于 作者: Y YuanRisheng 提交者: GitHub

[PTen]Elementwise_div Kernel Refactor (#37418)

* elementwise_div refactor

* fix compile bugs in windows ci
上级 c5ad3d06
......@@ -22,31 +22,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
struct SameDimsElemwiseDiv<
platform::CPUDeviceContext, T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
void operator()(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z) {
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
blas.VDIV(x->numel(), x->data<T>(), y->data<T>(), z->data<T>());
}
};
// use default div function for int32/int64 type because of divison zero
// checking.
template <typename T>
struct SameDimsElemwiseDiv<
platform::CPUDeviceContext, T,
typename std::enable_if<!std::is_floating_point<T>::value>::type> {
void operator()(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z) {
default_elementwise_div<platform::CPUDeviceContext, T>(ctx, x, y, z);
}
};
class ElementwiseDivOpMaker : public ElementwiseOpMaker {
protected:
std::string GetName() const override { return "Div"; }
......
......@@ -23,22 +23,6 @@ namespace plat = paddle::platform;
namespace paddle {
namespace operators {
template <typename T>
class ElementwiseDivKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, DivFunctor<T>());
}
};
template <typename T>
static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y,
const T* out,
......
......@@ -23,6 +23,12 @@ limitations under the License. */
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/framework/pten_utils.h"
// only can include the headers in paddle/pten/include dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/math.h"
namespace paddle {
namespace operators {
......@@ -42,13 +48,6 @@ void default_elementwise_div(const framework::ExecutionContext& ctx,
}
}
template <typename DeviceContext, typename T, class Enable = void>
struct SameDimsElemwiseDiv {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z);
};
template <typename DeviceContext, typename T>
class ElementwiseDivKernel : public framework::OpKernel<T> {
public:
......@@ -58,13 +57,13 @@ class ElementwiseDivKernel : public framework::OpKernel<T> {
auto* z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
auto dims_equal = x->dims() == y->dims();
if (dims_equal) {
SameDimsElemwiseDiv<DeviceContext, T> same_dims_div;
same_dims_div(ctx, x, y, z);
} else {
default_elementwise_div<DeviceContext, T>(ctx, x, y, z);
}
auto& dev_ctx = ctx.device_context<DeviceContext>();
int axis = ctx.Attr<int>("axis");
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
pten::ElementwiseDiv<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis,
pt_z.get());
}
};
......
......@@ -154,6 +154,12 @@ class ElementwiseOp : public framework::OperatorWithKernel {
{"axis"}, {"Out"});
}
}
if (Type() == "elementwise_div") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
return framework::KernelSignature("elementwise_div", {"X", "Y"},
{"axis"}, {"Out"});
}
}
return framework::KernelSignature("None", {"X"}, {}, {"Out"});
}
};
......
......@@ -26,5 +26,7 @@ PD_DLL_DECL Tensor mean(const Tensor& x);
PD_DLL_DECL Tensor add(const Tensor& x, const Tensor& y);
PD_DLL_DECL Tensor subtract(const Tensor& x, const Tensor& y);
PD_DLL_DECL Tensor divide(const Tensor& x, const Tensor& y);
} // namespace experimental
} // namespace paddle
......@@ -137,6 +137,41 @@ PD_DLL_DECL Tensor subtract(const Tensor& x, const Tensor& y) {
return out;
}
PD_DLL_DECL Tensor divide(const Tensor& x, const Tensor& y) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"elementwise_div", kernel_key);
// 2. Get Device Context
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto kernel_context = pten::KernelContext(dev_ctx);
// 3. Auto data transform
auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
kernel_context.EmplaceBackInput(dense_x);
auto dense_y = std::dynamic_pointer_cast<pten::DenseTensor>(y.impl());
kernel_context.EmplaceBackInput(dense_y);
kernel_context.EmplaceBackAttr(-1);
// 4. InferShape
auto out_meta = ElementwiseInferShape(dense_x->meta(), dense_y->meta(), -1);
// 5. Prepare outputs
Tensor out;
const auto allocator = std::make_shared<DefaultAllocator>(
pten::TransToFluidPlace(kernel_key.backend()));
auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta);
kernel_context.EmplaceBackOutput(dense_out);
out.set_impl(dense_out);
// 6. Call kernel
kernel(&kernel_context);
return out;
}
} // namespace experimental
} // namespace paddle
......
......@@ -75,7 +75,7 @@ DenseTensor Scale(const ContextT& dev_ctx,
}
template <typename T, typename ContextT>
DenseTensor ElementwiseAdd(const ContextT& dev_ctx,
DenseTensor Add(const ContextT& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis) {
......@@ -102,4 +102,17 @@ DenseTensor Subtract(const ContextT& dev_ctx,
return dense_out;
}
template <typename T, typename ContextT>
DenseTensor Divide(const ContextT& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis) {
auto out_meta = ElementwiseInferShape(x.meta(), y.meta(), axis);
const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace());
pten::DenseTensor dense_out(allocator, out_meta);
ElementwiseDiv<T>(dev_ctx, x, y, axis, &dense_out);
return dense_out;
}
} // namespace pten
......@@ -114,6 +114,30 @@ void ElementwiseSub(const CPUContext& dev_ctx,
}
}
template <typename T>
void ElementwiseDiv(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
// allocate memory for out
out->mutable_data<T>();
if (x.dims() == y.dims() && std::is_floating_point<T>::value) {
SameDimsElementwiseCompute<general::SameDimsDivFunctor<CPUContext, T>>()(
dev_ctx, x, y, out);
} else {
auto x_dims = x.dims();
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseCompute<general::DivFunctor<T>, T>(
dev_ctx, x, y, axis, general::DivFunctor<T>(), out);
} else {
ElementwiseCompute<general::InverseDivFunctor<T>, T>(
dev_ctx, x, y, axis, general::InverseDivFunctor<T>(), out);
}
}
}
} // namespace pten
// TODO(chenweihang): replace by better impl
......@@ -174,3 +198,13 @@ PT_REGISTER_KERNEL("elementwise_sub",
int64_t,
complex64,
complex128) {}
PT_REGISTER_KERNEL("elementwise_div",
CPU,
ANY,
pten::ElementwiseDiv,
float,
double,
int,
int64_t,
complex64,
complex128) {}
......@@ -60,4 +60,10 @@ void ElementwiseSub(const CPUContext& dev_ctx,
int axis,
DenseTensor* out);
template <typename T>
void ElementwiseDiv(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
} // namespace pten
......@@ -158,6 +158,23 @@ void ElementwiseSub(const CUDAContext& dev_ctx,
dev_ctx, inputs, &outputs, axis, general::SubFunctor<T>());
}
template <typename T>
void ElementwiseDiv(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
std::vector<const DenseTensor*> inputs;
std::vector<DenseTensor*> outputs;
inputs.emplace_back(&x);
inputs.emplace_back(&y);
// allocate memory for out
out->mutable_data<T>();
outputs.emplace_back(out);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, inputs, &outputs, axis, general::DivFunctor<T>());
}
} // namespace pten
// TODO(chenweihang): replace by better impl
......@@ -217,3 +234,14 @@ PT_REGISTER_KERNEL("elementwise_sub",
float16,
complex64,
complex128) {}
PT_REGISTER_KERNEL("elementwise_div",
CUDA,
ANY,
pten::ElementwiseDiv,
float,
double,
int,
int64_t,
float16,
complex64,
complex128) {}
......@@ -62,6 +62,13 @@ void ElementwiseSub(const CUDAContext& dev_ctx,
int axis,
DenseTensor* out);
template <typename T>
void ElementwiseDiv(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
} // namespace pten
#endif
......@@ -38,5 +38,14 @@ void ElementwiseSub(const DevCtx& dev_ctx,
blas.VSUB(x.numel(), x.data<T>(), y.data<T>(), out->mutable_data<T>());
}
template <typename DevCtx, typename T>
void ElementwiseDiv(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
auto blas = paddle::operators::math::GetBlas<DevCtx, T>(dev_ctx);
blas.VDIV(x.numel(), x.data<T>(), y.data<T>(), out->mutable_data<T>());
}
} // namespace blas
} // namespace pten
......@@ -114,5 +114,65 @@ struct InverseSubFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b - a; }
};
// Divide
template <typename DevCtx, typename T, class Enable = void>
struct SameDimsDivFunctor {
void operator()(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* z);
};
template <typename DevCtx, typename T>
struct SameDimsDivFunctor<
DevCtx,
T,
typename std::enable_if<!std::is_floating_point<T>::value>::type> {
void operator()(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* z) {
paddle::platform::errors::InvalidArgument(
"If use SameDimsDivFunctor, template args(T) must be floating point. ");
}
};
template <typename DevCtx, typename T>
struct SameDimsDivFunctor<
DevCtx,
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
void operator()(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* z) {
blas::ElementwiseDiv<DevCtx, T>(dev_ctx, x, y, z);
}
};
#define DIV_ERROR_INFO \
"InvalidArgumentError: Integer division by zero encountered in " \
"(floor) divide. Please check the input value."
template <typename T, typename Enable = void>
struct DivFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a / b; }
};
template <typename T>
struct DivFunctor<T,
typename std::enable_if<std::is_integral<T>::value>::type> {
inline HOSTDEVICE T operator()(const T& a, const T& b) const {
// For int32/int64, need to check whether the divison is zero.
PADDLE_ENFORCE(b != 0, DIV_ERROR_INFO);
return a / b;
}
};
template <typename T, typename Enable = void>
struct InverseDivFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b / a; }
};
} // namespace general
} // namespace pten
......@@ -131,3 +131,57 @@ TEST(API, subtract) {
ASSERT_NEAR(expect_result[0][1], actual_result1, 1e-6f);
ASSERT_NEAR(expect_result[1][0], actual_result2, 1e-6f);
}
// TODO(chenweihang): Remove this test after the API is used in the dygraph
TEST(API, divide) {
// 1. create tensor
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
auto dense_x = std::make_shared<pten::DenseTensor>(
alloc,
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({3, 10}),
pten::DataLayout::NCHW));
auto* dense_x_data = dense_x->mutable_data<float>();
auto dense_y = std::make_shared<pten::DenseTensor>(
alloc,
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({10}),
pten::DataLayout::NCHW));
auto* dense_y_data = dense_y->mutable_data<float>();
float div[3][10] = {0.0};
for (size_t i = 0; i < 3; ++i) {
for (size_t j = 0; j < 10; ++j) {
dense_x_data[i * 10 + j] = (i * 10 + j) * 1.0;
div[i][j] = (i * 10 + j) * 1.0 / (j * 2.0 + 1);
}
}
for (size_t i = 0; i < 10; ++i) {
dense_y_data[i] = i * 2.0 + 1;
}
paddle::experimental::Tensor x(dense_x);
paddle::experimental::Tensor y(dense_y);
// 2. test API
auto out = paddle::experimental::divide(x, y);
// 3. check result
ASSERT_EQ(out.shape().size(), 2UL);
ASSERT_EQ(out.shape()[0], 3);
ASSERT_EQ(out.numel(), 30);
ASSERT_EQ(out.is_cpu(), true);
ASSERT_EQ(out.type(), pten::DataType::FLOAT32);
ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
ASSERT_EQ(out.initialized(), true);
auto expect_result = div;
auto dense_out = std::dynamic_pointer_cast<pten::DenseTensor>(out.impl());
auto actual_result0 = dense_out->data<float>()[0];
auto actual_result1 = dense_out->data<float>()[1];
auto actual_result2 = dense_out->data<float>()[10];
ASSERT_NEAR(expect_result[0][0], actual_result0, 1e-6f);
ASSERT_NEAR(expect_result[0][1], actual_result1, 1e-6f);
ASSERT_NEAR(expect_result[1][0], actual_result2, 1e-6f);
}
......@@ -24,7 +24,7 @@ limitations under the License. */
namespace framework = paddle::framework;
using DDim = paddle::framework::DDim;
TEST(DEV_API, elementwise_add) {
TEST(DEV_API, add) {
// 1. create tensor
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
......@@ -56,7 +56,7 @@ TEST(DEV_API, elementwise_add) {
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API
auto dense_out = pten::ElementwiseAdd<float>(
auto dense_out = pten::Add<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
dense_y,
......@@ -129,3 +129,56 @@ TEST(DEV_API, subtract) {
ASSERT_NEAR(expect_result[0][1], actual_result1, 1e-6f);
ASSERT_NEAR(expect_result[1][0], actual_result2, 1e-6f);
}
TEST(DEV_API, divide) {
// 1. create tensor
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
pten::DenseTensor dense_x(alloc,
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({3, 10}),
pten::DataLayout::NCHW));
auto* dense_x_data = dense_x.mutable_data<float>();
pten::DenseTensor dense_y(alloc,
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({10}),
pten::DataLayout::NCHW));
auto* dense_y_data = dense_y.mutable_data<float>();
float div[3][10] = {0.0};
for (size_t i = 0; i < 3; ++i) {
for (size_t j = 0; j < 10; ++j) {
dense_x_data[i * 10 + j] = (i * 10 + j) * 1.0;
div[i][j] = (i * 10 + j) * 1.0 / (j * 2.0 + 1);
}
}
for (size_t i = 0; i < 10; ++i) {
dense_y_data[i] = i * 2.0 + 1;
}
int axis = 1;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API
auto dense_out = pten::Divide<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
dense_y,
axis);
// 3. check result
ASSERT_EQ(dense_out.dims().size(), 2);
ASSERT_EQ(dense_out.dims()[0], 3);
ASSERT_EQ(dense_out.meta().dtype, pten::DataType::FLOAT32);
ASSERT_EQ(dense_out.meta().layout, pten::DataLayout::NCHW);
auto expect_result = div;
auto actual_result0 = dense_out.data<float>()[0];
auto actual_result1 = dense_out.data<float>()[1];
auto actual_result2 = dense_out.data<float>()[10];
ASSERT_NEAR(expect_result[0][0], actual_result0, 1e-6f);
ASSERT_NEAR(expect_result[0][1], actual_result1, 1e-6f);
ASSERT_NEAR(expect_result[1][0], actual_result2, 1e-6f);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册