Unverified commit c5e857d4. Author: YuanRisheng. Committer: GitHub.

elementwise_mul refactor (#37471)

* elementwise_mul refactor

* polish test code

* delete redundant code

* fix bugs when running test_multiply

* adjust the location of the macro

* fix bugs when running CI
Parent commit: 0f24de83
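For orientation, a minimal sketch of how the new C++ API introduced by this refactor is used. It mirrors the TEST(API, multiply) case further down in the diff, with the data-filling loops omitted; the allocator and header setup are assumptions taken from the test files, not part of this change.

// Sketch only: CPU place, float32 tensors, shapes as in the test below.
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
    paddle::platform::CPUPlace());
auto dense_x = std::make_shared<pten::DenseTensor>(
    alloc,
    pten::DenseTensorMeta(pten::DataType::FLOAT32,
                          paddle::framework::make_ddim({3, 10}),
                          pten::DataLayout::NCHW));
auto dense_y = std::make_shared<pten::DenseTensor>(
    alloc,
    pten::DenseTensorMeta(pten::DataType::FLOAT32,
                          paddle::framework::make_ddim({10}),
                          pten::DataLayout::NCHW));
// ... fill dense_x / dense_y through mutable_data<float>() ...
paddle::experimental::Tensor x(dense_x);
paddle::experimental::Tensor y(dense_y);
// multiply() dispatches to the new pten "elementwise_mul" kernel and
// broadcasts y across the rows of x (axis is passed as -1 internally).
auto out = paddle::experimental::multiply(x, y);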
......@@ -17,6 +17,10 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
// Only headers from the paddle/pten/api dirs can be included here.
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/math.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
......@@ -28,15 +32,39 @@ class ElementwiseMulKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
framework::Tensor x_for_selectedrows;
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
auto x_var = ctx.InputVar("X");
PADDLE_ENFORCE_EQ(x_var != nullptr, true,
platform::errors::InvalidArgument(
"Cannot get input Variable X, Variable name = %s.",
ctx.InputName("X")));
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs, &x_for_selectedrows);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, MulFunctor<T>());
if (x_var->IsType<framework::SelectedRows>()) {
framework::Tensor x_for_selectedrows;
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
int axis =
PackTensorsIntoVector<T>(ctx, &ins, &outs, &x_for_selectedrows);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, MulFunctor<T>());
} else if (x_var->IsType<framework::LoDTensor>()) {
auto* x_lod = ctx.Input<framework::LoDTensor>("X");
auto* y_lod = ctx.Input<framework::LoDTensor>("Y");
auto* z_lod = ctx.Output<framework::LoDTensor>("Out");
z_lod->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x_lod);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y_lod);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z_lod);
pten::ElementwiseMul<T>(cuda_ctx, *pt_x.get(), *pt_y.get(), axis,
pt_z.get());
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"X's type[%s] is not supported by elementwise_op. X's type should be "
"LoDTensor or SelectedRows.",
framework::ToTypeName(x_var->Type())));
}
}
};
......
......@@ -15,11 +15,16 @@ limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/cpu_info.h"
// Only headers from the paddle/pten/include dirs can be included here.
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/math.h"
namespace paddle {
namespace operators {
......@@ -106,24 +111,32 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
out_sele->mutable_value()->Resize(x_sele.value().dims());
out_sele->mutable_value()->mutable_data(ctx.GetPlace(), x.type());
z = ctx.Output<framework::SelectedRows>("Out")->mutable_value();
z->mutable_data<T>(ctx.GetPlace());
auto dims_equal = x.dims() == y->dims();
if (dims_equal) {
SameDimsElemwiseMul<DeviceContext, T> same_dims_mul;
same_dims_mul(ctx, &x, y, z);
} else {
default_elementwise_mul<DeviceContext, T>(ctx, &x, y, z);
}
} else if (x_var->IsType<framework::LoDTensor>()) {
x = x_var->Get<framework::LoDTensor>();
z = ctx.Output<framework::LoDTensor>("Out");
auto* x_lod = ctx.Input<framework::LoDTensor>("X");
auto* z_lod = ctx.Output<framework::LoDTensor>("Out");
z_lod->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.device_context<DeviceContext>();
int axis = ctx.Attr<int>("axis");
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x_lod);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z_lod);
pten::ElementwiseMul<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis,
pt_z.get());
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"X's type[%s] is not supported by elementwise_op. X's type should be "
"LoDTensor or SelectedRows.",
framework::ToTypeName(x_var->Type())));
}
z->mutable_data<T>(ctx.GetPlace());
auto dims_equal = x.dims() == y->dims();
if (dims_equal) {
SameDimsElemwiseMul<DeviceContext, T> same_dims_mul;
same_dims_mul(ctx, &x, y, z);
} else {
default_elementwise_mul<DeviceContext, T>(ctx, &x, y, z);
}
}
};
template <typename T>
......
......@@ -160,6 +160,12 @@ class ElementwiseOp : public framework::OperatorWithKernel {
{"axis"}, {"Out"});
}
}
if (Type() == "elementwise_mul") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
return framework::KernelSignature("elementwise_mul", {"X", "Y"},
{"axis"}, {"Out"});
}
}
return framework::KernelSignature("None", {"X"}, {}, {"Out"});
}
};
......
......@@ -28,5 +28,8 @@ PD_DLL_DECL Tensor add(const Tensor& x, const Tensor& y);
PD_DLL_DECL Tensor subtract(const Tensor& x, const Tensor& y);
PD_DLL_DECL Tensor divide(const Tensor& x, const Tensor& y);
PD_DLL_DECL Tensor multiply(const Tensor& x, const Tensor& y);
} // namespace experimental
} // namespace paddle
......@@ -172,6 +172,41 @@ PD_DLL_DECL Tensor divide(const Tensor& x, const Tensor& y) {
return out;
}
PD_DLL_DECL Tensor multiply(const Tensor& x, const Tensor& y) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"elementwise_mul", kernel_key);
// 2. Get Device Context
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto kernel_context = pten::KernelContext(dev_ctx);
// 3. Auto data transform
auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
kernel_context.EmplaceBackInput(dense_x);
auto dense_y = std::dynamic_pointer_cast<pten::DenseTensor>(y.impl());
kernel_context.EmplaceBackInput(dense_y);
kernel_context.EmplaceBackAttr(-1);
// 4. InferShape
auto out_meta = ElementwiseInferShape(dense_x->meta(), dense_y->meta(), -1);
// 5. Prepare outputs
Tensor out;
const auto allocator = std::make_shared<DefaultAllocator>(
pten::TransToFluidPlace(kernel_key.backend()));
auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta);
kernel_context.EmplaceBackOutput(dense_out);
out.set_impl(dense_out);
// 6. Call kernel
kernel(&kernel_context);
return out;
}
} // namespace experimental
} // namespace paddle
......
......@@ -234,9 +234,14 @@ void ReMakePtenDenseTensorFromVar(const framework::Variable& variable,
const pten::TensorArgDef& arg_def,
pten::DenseTensor* dst) {
auto expected_place = pten::TransToFluidPlace(arg_def.backend);
if (variable.IsType<framework::LoDTensor>()) {
const auto& tensor = variable.Get<framework::LoDTensor>();
// check input dtype before ReMakePtenDenseTensor
PADDLE_ENFORCE(
(arg_def.dtype == pten::TransToPtenDataType(tensor.type())),
paddle::platform::errors::InvalidArgument(
"The type of the input data is different from the type of the "
"argument's definition in the kernel."));
if (!platform::is_same_place(tensor.place(), expected_place)) {
framework::LoDTensor tmp_tensor;
framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
......@@ -248,6 +253,11 @@ void ReMakePtenDenseTensorFromVar(const framework::Variable& variable,
// TODO(chenweihang): now we don't deal with row and height
// by xiaowei's advice
const auto& tensor = variable.Get<framework::SelectedRows>();
PADDLE_ENFORCE(
(arg_def.dtype == pten::TransToPtenDataType(tensor.value().type())),
paddle::platform::errors::InvalidArgument(
"The type of the input data is different from the type of the "
"argument's definition in the kernel."));
if (!platform::is_same_place(tensor.value().place(), expected_place)) {
framework::Tensor tmp_tensor;
TensorCopySync(tensor.value(), expected_place, &tmp_tensor);
......
......@@ -115,4 +115,18 @@ DenseTensor Divide(const ContextT& dev_ctx,
ElementwiseDiv<T>(dev_ctx, x, y, axis, &dense_out);
return dense_out;
}
template <typename T, typename ContextT>
DenseTensor Multiply(const ContextT& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis) {
auto out_meta = ElementwiseInferShape(x.meta(), y.meta(), axis);
const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace());
pten::DenseTensor dense_out(allocator, out_meta);
ElementwiseMul<T>(dev_ctx, x, y, axis, &dense_out);
return dense_out;
}
} // namespace pten
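As a usage note, a minimal sketch of calling the new device-level pten::Multiply helper directly. It follows the TEST(DEV_API, multiply) case later in this diff and assumes dense_x ({3, 10}) and dense_y ({10}) are already-initialized pten::DenseTensor objects on the CPU.

// Sketch only: look up a CPU device context and invoke the new helper.
paddle::platform::DeviceContextPool& pool =
    paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
int axis = 1;  // broadcast dense_y along dimension 1 of dense_x
auto dense_out = pten::Multiply<float>(
    *(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
    dense_x,
    dense_y,
    axis);
// dense_out is a freshly allocated DenseTensor holding the product.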
......@@ -64,56 +64,6 @@ void ScaleHost(const CPUContext& dev_ctx,
out);
}
template <typename T>
void ElementwiseAdd(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
// allocate memory for out
out->mutable_data<T>();
if (x.dims() == y.dims()) {
SameDimsElementwiseCompute<general::SameDimsAddFunctor<CPUContext, T>>()(
dev_ctx, x, y, out);
} else {
auto x_dims = x.dims();
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseCompute<general::AddFunctor<T>, T>(
dev_ctx, x, y, axis, general::AddFunctor<T>(), out);
} else {
ElementwiseCompute<general::InverseAddFunctor<T>, T>(
dev_ctx, x, y, axis, general::InverseAddFunctor<T>(), out);
}
}
}
template <typename T>
void ElementwiseSub(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
// allocate memory for out
out->mutable_data<T>();
if (x.dims() == y.dims()) {
SameDimsElementwiseCompute<general::SameDimsSubFunctor<CPUContext, T>>()(
dev_ctx, x, y, out);
} else {
auto x_dims = x.dims();
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseCompute<general::SubFunctor<T>, T>(
dev_ctx, x, y, axis, general::SubFunctor<T>(), out);
} else {
ElementwiseCompute<general::InverseSubFunctor<T>, T>(
dev_ctx, x, y, axis, general::InverseSubFunctor<T>(), out);
}
}
}
template <typename T>
void ElementwiseDiv(const CPUContext& dev_ctx,
const DenseTensor& x,
......@@ -138,6 +88,15 @@ void ElementwiseDiv(const CPUContext& dev_ctx,
}
}
// Create the definition of ElementwiseAdd
DEFINE_CPU_ELEMENTWISE_OP(Add)
// Create the definition of ElementwiseSub
DEFINE_CPU_ELEMENTWISE_OP(Sub)
// Create the definition of ElementwiseMul
DEFINE_CPU_ELEMENTWISE_OP(Mul)
} // namespace pten
// TODO(chenweihang): replace by better impl
......@@ -208,3 +167,14 @@ PT_REGISTER_KERNEL("elementwise_div",
int64_t,
complex64,
complex128) {}
PT_REGISTER_KERNEL("elementwise_mul",
CPU,
ANY,
pten::ElementwiseMul,
float,
double,
int,
int64_t,
bool,
complex64,
complex128) {}
......@@ -66,4 +66,36 @@ void ElementwiseDiv(const CPUContext& dev_ctx,
const DenseTensor& y,
int axis,
DenseTensor* out);
template <typename T>
void ElementwiseMul(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
} // namespace pten
#define DEFINE_CPU_ELEMENTWISE_OP(name) \
template <typename T> \
void Elementwise##name(const CPUContext& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
int axis, \
DenseTensor* out) { \
out->mutable_data<T>(); \
if (x.dims() == y.dims()) { \
SameDimsElementwiseCompute< \
general::SameDims##name##Functor<CPUContext, T>>()( \
dev_ctx, x, y, out); \
} else { \
auto x_dims = x.dims(); \
auto y_dims = y.dims(); \
if (x_dims.size() >= y_dims.size()) { \
ElementwiseCompute<general::name##Functor<T>, T>( \
dev_ctx, x, y, axis, general::name##Functor<T>(), out); \
} else { \
ElementwiseCompute<general::Inverse##name##Functor<T>, T>( \
dev_ctx, x, y, axis, general::Inverse##name##Functor<T>(), out); \
} \
} \
}
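For readability, a sketch of what DEFINE_CPU_ELEMENTWISE_OP(Mul) expands to; it mirrors the hand-written ElementwiseAdd/ElementwiseSub bodies that this change removes from the .cc file above.

// Approximate expansion of DEFINE_CPU_ELEMENTWISE_OP(Mul):
template <typename T>
void ElementwiseMul(const CPUContext& dev_ctx,
                    const DenseTensor& x,
                    const DenseTensor& y,
                    int axis,
                    DenseTensor* out) {
  // allocate memory for out
  out->mutable_data<T>();
  if (x.dims() == y.dims()) {
    SameDimsElementwiseCompute<general::SameDimsMulFunctor<CPUContext, T>>()(
        dev_ctx, x, y, out);
  } else {
    auto x_dims = x.dims();
    auto y_dims = y.dims();
    if (x_dims.size() >= y_dims.size()) {
      ElementwiseCompute<general::MulFunctor<T>, T>(
          dev_ctx, x, y, axis, general::MulFunctor<T>(), out);
    } else {
      ElementwiseCompute<general::InverseMulFunctor<T>, T>(
          dev_ctx, x, y, axis, general::InverseMulFunctor<T>(), out);
    }
  }
}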
......@@ -124,56 +124,14 @@ void ScaleHost(const CUDAContext& dev_ctx,
out);
}
template <typename T>
void ElementwiseAdd(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
std::vector<const DenseTensor*> inputs;
std::vector<DenseTensor*> outputs;
inputs.emplace_back(&x);
inputs.emplace_back(&y);
// allocate memory for out
out->mutable_data<T>();
outputs.emplace_back(out);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, inputs, &outputs, axis, general::AddFunctor<T>());
}
template <typename T>
void ElementwiseSub(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
std::vector<const DenseTensor*> inputs;
std::vector<DenseTensor*> outputs;
inputs.emplace_back(&x);
inputs.emplace_back(&y);
// allocate memory for out
out->mutable_data<T>();
outputs.emplace_back(out);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, inputs, &outputs, axis, general::SubFunctor<T>());
}
template <typename T>
void ElementwiseDiv(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
std::vector<const DenseTensor*> inputs;
std::vector<DenseTensor*> outputs;
inputs.emplace_back(&x);
inputs.emplace_back(&y);
// allocate memory for out
out->mutable_data<T>();
outputs.emplace_back(out);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, inputs, &outputs, axis, general::DivFunctor<T>());
}
// Create the definition of ElementwiseAdd
DEFINE_CUDA_ELEMENTWISE_OP(Add)
// Create the definition of ElementwiseSub
DEFINE_CUDA_ELEMENTWISE_OP(Sub)
// Create the definition of ElementwiseMul
DEFINE_CUDA_ELEMENTWISE_OP(Mul)
// Create the definition of ElementwiseDiv
DEFINE_CUDA_ELEMENTWISE_OP(Div)
} // namespace pten
......@@ -245,3 +203,15 @@ PT_REGISTER_KERNEL("elementwise_div",
float16,
complex64,
complex128) {}
PT_REGISTER_KERNEL("elementwise_mul",
CUDA,
ANY,
pten::ElementwiseMul,
float,
double,
int,
int64_t,
bool,
float16,
complex64,
complex128) {}
......@@ -69,6 +69,29 @@ void ElementwiseDiv(const CUDAContext& dev_ctx,
int axis,
DenseTensor* out);
template <typename T>
void ElementwiseMul(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
} // namespace pten
#define DEFINE_CUDA_ELEMENTWISE_OP(name) \
template <typename T> \
void Elementwise##name(const CUDAContext& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
int axis, \
DenseTensor* out) { \
std::vector<const DenseTensor*> inputs; \
std::vector<DenseTensor*> outputs; \
inputs.emplace_back(&x); \
inputs.emplace_back(&y); \
outputs.emplace_back(out); \
out->mutable_data<T>(); \
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>( \
dev_ctx, inputs, &outputs, axis, general::name##Functor<T>()); \
}
#endif
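Likewise, a sketch of the expansion of DEFINE_CUDA_ELEMENTWISE_OP(Mul), following the pattern of the hand-written CUDA ElementwiseAdd/ElementwiseSub/ElementwiseDiv definitions removed above.

// Approximate expansion of DEFINE_CUDA_ELEMENTWISE_OP(Mul):
template <typename T>
void ElementwiseMul(const CUDAContext& dev_ctx,
                    const DenseTensor& x,
                    const DenseTensor& y,
                    int axis,
                    DenseTensor* out) {
  std::vector<const DenseTensor*> inputs;
  std::vector<DenseTensor*> outputs;
  inputs.emplace_back(&x);
  inputs.emplace_back(&y);
  outputs.emplace_back(out);
  // allocate memory for out
  out->mutable_data<T>();
  LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
      dev_ctx, inputs, &outputs, axis, general::MulFunctor<T>());
}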
......@@ -47,5 +47,13 @@ void ElementwiseDiv(const DevCtx& dev_ctx,
blas.VDIV(x.numel(), x.data<T>(), y.data<T>(), out->mutable_data<T>());
}
template <typename DevCtx, typename T>
void ElementwiseMul(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
auto blas = paddle::operators::math::GetBlas<DevCtx, T>(dev_ctx);
blas.VMUL(x.numel(), x.data<T>(), y.data<T>(), out->mutable_data<T>());
}
} // namespace blas
} // namespace pten
......@@ -45,5 +45,17 @@ void ElementwiseSub(const DevCtx& dev_ctx,
eigen_z.device(place) = eigen_x - eigen_y;
}
template <typename DevCtx, typename T>
void ElementwiseMul(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
auto eigen_x = pten::EigenVector<T>::Flatten(x);
auto eigen_y = pten::EigenVector<T>::Flatten(y);
auto eigen_z = pten::EigenVector<T>::Flatten(*out);
auto& place = *dev_ctx.eigen_device();
eigen_z.device(place) = eigen_x * eigen_y;
}
} // namespace eigen
} // namespace pten
......@@ -174,5 +174,48 @@ struct InverseDivFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b / a; }
};
// Multiply
template <typename DevCtx, typename T, class Enable = void>
struct SameDimsMulFunctor {
void operator()(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* z);
};
template <typename DevCtx, typename T>
struct SameDimsMulFunctor<
DevCtx,
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
void operator()(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* z) {
blas::ElementwiseMul<DevCtx, T>(dev_ctx, x, y, z);
}
};
template <typename DevCtx, typename T>
struct SameDimsMulFunctor<
DevCtx,
T,
typename std::enable_if<!std::is_floating_point<T>::value>::type> {
void operator()(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* z) {
eigen::ElementwiseMul<DevCtx, T>(dev_ctx, x, y, z);
}
};
template <typename T>
struct MulFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; }
};
template <typename T>
struct InverseMulFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b * a; }
};
} // namespace general
} // namespace pten
......@@ -164,6 +164,7 @@ TEST(API, divide) {
for (size_t i = 0; i < 10; ++i) {
dense_y_data[i] = i * 2.0 + 1;
}
paddle::experimental::Tensor x(dense_x);
paddle::experimental::Tensor y(dense_y);
......@@ -189,5 +190,57 @@ TEST(API, divide) {
ASSERT_NEAR(expect_result[1][0], actual_result2, 1e-6f);
}
TEST(API, multiply) {
// 1. create tensor
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
auto dense_x = std::make_shared<pten::DenseTensor>(
alloc,
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({3, 10}),
pten::DataLayout::NCHW));
auto* dense_x_data = dense_x->mutable_data<float>();
auto dense_y = std::make_shared<pten::DenseTensor>(
alloc,
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({10}),
pten::DataLayout::NCHW));
auto* dense_y_data = dense_y->mutable_data<float>();
float mul[3][10] = {0.0};
for (size_t i = 0; i < 3; ++i) {
for (size_t j = 0; j < 10; ++j) {
dense_x_data[i * 10 + j] = (i * 10 + j) * 1.0;
mul[i][j] = (i * 10 + j) * 1.0 * j * 2.0;
}
}
for (size_t i = 0; i < 10; ++i) {
dense_y_data[i] = i * 2.0;
}
paddle::experimental::Tensor x(dense_x);
paddle::experimental::Tensor y(dense_y);
// 2. test API
auto out = paddle::experimental::multiply(x, y);
// 3. check result
ASSERT_EQ(out.shape().size(), 2UL);
ASSERT_EQ(out.shape()[0], 3);
ASSERT_EQ(out.numel(), 30);
ASSERT_EQ(out.is_cpu(), true);
ASSERT_EQ(out.type(), pten::DataType::FLOAT32);
ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
ASSERT_EQ(out.initialized(), true);
auto expect_result = mul;
auto dense_out = std::dynamic_pointer_cast<pten::DenseTensor>(out.impl());
auto actual_result0 = dense_out->data<float>()[0];
auto actual_result1 = dense_out->data<float>()[1];
auto actual_result2 = dense_out->data<float>()[10];
ASSERT_NEAR(expect_result[0][0], actual_result0, 1e-6f);
ASSERT_NEAR(expect_result[0][1], actual_result1, 1e-6f);
ASSERT_NEAR(expect_result[1][0], actual_result2, 1e-6f);
}
} // namespace tests
} // namespace paddle
......@@ -68,8 +68,8 @@ TEST(DEV_API, add) {
// 3. check result
ASSERT_EQ(dense_out.dims().size(), 2);
ASSERT_EQ(dense_out.dims()[0], 3);
ASSERT_EQ(dense_out.meta().dtype, pten::DataType::FLOAT32);
ASSERT_EQ(dense_out.meta().layout, pten::DataLayout::NCHW);
ASSERT_EQ(dense_out.dtype(), pten::DataType::FLOAT32);
ASSERT_EQ(dense_out.layout(), pten::DataLayout::NCHW);
auto expect_result = sum;
auto actual_result0 = dense_out.data<float>()[0];
......@@ -174,8 +174,8 @@ TEST(DEV_API, divide) {
// 3. check result
ASSERT_EQ(dense_out.dims().size(), 2);
ASSERT_EQ(dense_out.dims()[0], 3);
ASSERT_EQ(dense_out.meta().dtype, pten::DataType::FLOAT32);
ASSERT_EQ(dense_out.meta().layout, pten::DataLayout::NCHW);
ASSERT_EQ(dense_out.dtype(), pten::DataType::FLOAT32);
ASSERT_EQ(dense_out.layout(), pten::DataLayout::NCHW);
auto expect_result = div;
auto actual_result0 = dense_out.data<float>()[0];
......@@ -186,5 +186,57 @@ TEST(DEV_API, divide) {
ASSERT_NEAR(expect_result[1][0], actual_result2, 1e-6f);
}
TEST(DEV_API, multiply) {
// 1. create tensor
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
pten::DenseTensor dense_x(alloc,
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({3, 10}),
pten::DataLayout::NCHW));
auto* dense_x_data = dense_x.mutable_data<float>();
pten::DenseTensor dense_y(alloc,
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({10}),
pten::DataLayout::NCHW));
auto* dense_y_data = dense_y.mutable_data<float>();
float mul[3][10] = {0.0};
for (size_t i = 0; i < 3; ++i) {
for (size_t j = 0; j < 10; ++j) {
dense_x_data[i * 10 + j] = (i * 10 + j) * 1.0;
mul[i][j] = (i * 10 + j) * 1.0 * j * 2.0;
}
}
for (size_t i = 0; i < 10; ++i) {
dense_y_data[i] = i * 2.0;
}
int axis = 1;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API
auto dense_out = pten::Multiply<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
dense_y,
axis);
// 3. check result
ASSERT_EQ(dense_out.dims().size(), 2);
ASSERT_EQ(dense_out.dims()[0], 3);
ASSERT_EQ(dense_out.dtype(), pten::DataType::FLOAT32);
ASSERT_EQ(dense_out.layout(), pten::DataLayout::NCHW);
auto expect_result = mul;
auto actual_result0 = dense_out.data<float>()[0];
auto actual_result1 = dense_out.data<float>()[1];
auto actual_result2 = dense_out.data<float>()[10];
ASSERT_NEAR(expect_result[0][0], actual_result0, 1e-6f);
ASSERT_NEAR(expect_result[0][1], actual_result1, 1e-6f);
ASSERT_NEAR(expect_result[1][0], actual_result2, 1e-6f);
}
} // namespace tests
} // namespace pten