Unverified · Commit be817719 · Authored by: zyfncg · Committed by: GitHub

【PTen】Add dot and matmul grad kernel in pten (#38713)

* refactor matmul directory in pten

* fix merge conflict

* add dot_grad kernel

* add dot_grad kernel in pten

* add matmul_grad kernel

* update the code

* delete useless code in fluid

* fix some bugs when running the matmul grad kernel

* fix merge conflict

* refactor some code

* refactor code
Parent 5b940c44
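For orientation (an editorial note, not part of the original commit message): the dot_grad kernel ported by this change implements the standard dot-product backward that the removed DotGradFunction below encodes. For real inputs, with the dot taken over the last dimension of size B,

    out_j = sum_{i=1..B} x_{j,i} * y_{j,i},    dx_{j,i} = dout_j * y_{j,i},    dy_{j,i} = dout_j * x_{j,i},

and for complex inputs the remaining operand is conjugated first, i.e. dx_{j,i} = conj(y_{j,i}) * dout_j and dy_{j,i} = conj(x_{j,i}) * dout_j.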
...@@ -79,6 +79,9 @@ function(kernel_library TARGET)
  endif()
  list(APPEND all_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.h)
+ if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/impl/${TARGET}_impl.h)
+   list(APPEND all_srcs ${CMAKE_CURRENT_SOURCE_DIR}/impl/${TARGET}_impl.h)
+ endif()
  list(APPEND all_srcs ${common_srcs})
  list(APPEND all_srcs ${cpu_srcs})
  list(APPEND all_srcs ${gpu_srcs})
......
...@@ -1880,16 +1880,32 @@ void OperatorWithKernel::BuildPtenKernelContext(
    // Otherwise,we will create new storage.
    for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
      if (current_vector_size > start_idx + offset) {
-       experimental::ReMakePtenDenseTensorFromVar(
-           outs_vector[offset], out_def,
-           pt_kernel_context_->MutableOutputAt<pten::DenseTensor>(start_idx +
-                                                                  offset));
+       auto* buffer_tensor =
+           pt_kernel_context_->MutableOutputAt<pten::DenseTensor>(start_idx +
+                                                                  offset);
+       if (buffer_tensor) {
+         experimental::ReMakePtenDenseTensorFromVar(outs_vector[offset],
+                                                    out_def, buffer_tensor);
+       }
      } else {
        pt_kernel_context_->EmplaceBackOutputWithoutSetRange(
            experimental::MakePtenTensorBaseFromVar(outs_vector[offset],
                                                    out_def));
      }
    }
+   // Deal with the case that some outputs are NULL when run the kernel.
+   // For example : the outputs of matmul_grad are dx and dy,
+   // sometimes dx or dy may be NULL.
+   if (outs_vector.empty()) {
+     if (current_vector_size > start_idx) {
+       pt_kernel_context_->SetOutputWithoutSetRange(start_idx, {nullptr});
+     } else {
+       pt_kernel_context_->EmplaceBackOutputWithoutSetRange({nullptr});
+     }
+     end_idx = start_idx + 1;
+   }
    pt_kernel_context_->AssignOutputRange(std::make_pair(start_idx, end_idx),
                                          i);
  }
...@@ -2002,7 +2018,9 @@ void OperatorWithKernel::WriteBackToOutputs(RuntimeContext* ctx) const {
        range_pair.first, range_pair.second);
    for (size_t j = 0; j < pten_outs.size(); ++j) {
-     experimental::MakeVariableFromPtenTensor(pten_outs[j], outs_vector[j]);
+     if (pten_outs[j]) {
+       experimental::MakeVariableFromPtenTensor(pten_outs[j], outs_vector[j]);
+     }
    }
  }
}
......
...@@ -99,7 +99,7 @@ KernelSignatureMap& KernelSignatureMap::Instance() {
      const auto& op_type = pair.first;
      const auto* op_proto = pair.second.proto_;
      if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) &&
-         op_proto != nullptr) {
+         op_proto) {
        KernelArgsNameMakerByOpProto maker(op_proto);
        VLOG(10) << "Register kernel signature for " << op_type;
        auto success = kernel_signature_map_->map_
......
...@@ -338,19 +338,41 @@ static void BuildDygraphPtenKernelContext(
  for (size_t i = 0; i < output_names.size(); ++i) {
    auto& out_def = output_defs.at(i);
-   auto& outs_vector = outs.at(output_names[i]);
    size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second);
-   size_t end_idx = start_idx + outs_vector.size();
    auto current_vector_size = kernel_ctx->OutputsSize();
+   auto iter = outs.find(output_names[i]);
+   if (iter == outs.end()) {
+     if (current_vector_size > start_idx) {
+       kernel_ctx->SetOutputWithoutSetRange(start_idx, {nullptr});
+     } else {
+       kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
+     }
+     kernel_ctx->AssignOutputRange(std::make_pair(start_idx, start_idx + 1),
+                                   i);
+     continue;
+   }
+   auto& outs_vector = iter->second;
+   size_t end_idx = start_idx + outs_vector.size();
    // If the memory needed is less than the current memory allocated, we will
    // reuse the current memory by using ReMakePtenDenseTensorFromVar.
    // Otherwise,we will create new storage.
    for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
      if (current_vector_size > start_idx + offset) {
-       experimental::ReMakePtenDenseTensorFromVar(
-           outs_vector[offset]->MutableVar(), out_def,
-           kernel_ctx->MutableOutputAt<pten::DenseTensor>(start_idx + offset));
+       auto* buffer_tensor =
+           kernel_ctx->MutableOutputAt<pten::DenseTensor>(start_idx + offset);
+       if (buffer_tensor) {
+         experimental::ReMakePtenDenseTensorFromVar(
+             outs_vector[offset]->MutableVar(), out_def, buffer_tensor);
+       } else {
+         kernel_ctx->SetOutputWithoutSetRange(
+             start_idx + offset,
+             experimental::MakePtenTensorBaseFromVar(
+                 outs_vector[offset]->MutableVar(), out_def));
+       }
      } else {
        kernel_ctx->EmplaceBackOutputWithoutSetRange(
            experimental::MakePtenTensorBaseFromVar(
...@@ -465,15 +487,18 @@ static void WriteBackToOutputs(
  auto& output_names = std::get<2>(pt_kernel_signature.args);
  for (size_t i = 0; i < output_names.size(); ++i) {
-   auto& outs_vector = outs.at(output_names[i]);
+   auto iter = outs.find(output_names[i]);
+   if (iter != outs.end()) {
+     auto& outs_vector = iter->second;
      auto& range_pair = kernel_ctx->OutputRangeAt(i);
      auto pten_outs = kernel_ctx->MutableOutputBetween<pten::DenseTensor>(
          range_pair.first, range_pair.second);
      for (size_t j = 0; j < pten_outs.size(); ++j) {
        experimental::MakeVariableFromPtenTensor(pten_outs[j],
                                                 outs_vector[j]->MutableVar());
      }
+   }
  }
}
...@@ -529,6 +554,7 @@ static void PreparedOpRunImpl(
template <typename VarType>
static void PreparedOpRunPtImpl(
    const framework::OperatorBase& op,
+   const framework::OpKernelType& kernel_type,
    const framework::KernelSignature& pt_kernel_signature,
    const pten::Kernel& pt_kernel, pten::KernelContext* pt_kernel_context,
    platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
...@@ -558,7 +584,9 @@ static void PreparedOpRunPtImpl(
  pt_kernel_context->ClearData();
  // TODO(chenweihang): add debug flags later
- // TODO(chenweihang): deal with complex cases later
+ if (framework::IsComplexType(kernel_type.data_type_)) {
+   HandleComplexGradToRealGrad<VarType>(outs);
+ }
}
void PreparedOp::Run(const NameVarMap<VarBase>& ins,
...@@ -566,9 +594,9 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins,
                     const framework::AttributeMap& attrs,
                     const framework::AttributeMap& default_attrs) {
  if (run_pten_kernel_) {
-   PreparedOpRunPtImpl<VarBase>(op_, pt_kernel_signature_, pt_kernel_,
-                                pt_kernel_context_, dev_ctx_, ins, outs, attrs,
-                                default_attrs);
+   PreparedOpRunPtImpl<VarBase>(op_, kernel_type_, pt_kernel_signature_,
+                                pt_kernel_, pt_kernel_context_, dev_ctx_, ins,
+                                outs, attrs, default_attrs);
  } else {
    PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, dev_ctx_, ins,
                               outs, attrs, default_attrs);
...@@ -580,9 +608,9 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
                     const framework::AttributeMap& attrs,
                     const framework::AttributeMap& default_attrs) {
  if (run_pten_kernel_) {
-   PreparedOpRunPtImpl<VariableWrapper>(op_, pt_kernel_signature_, pt_kernel_,
-                                        pt_kernel_context_, dev_ctx_, ins,
-                                        outs, attrs, default_attrs);
+   PreparedOpRunPtImpl<VariableWrapper>(
+       op_, kernel_type_, pt_kernel_signature_, pt_kernel_, pt_kernel_context_,
+       dev_ctx_, ins, outs, attrs, default_attrs);
  } else {
    PreparedOpRunImpl<VariableWrapper>(op_, ctx_, kernel_type_, func_, dev_ctx_,
                                       ins, outs, attrs, default_attrs);
......
...@@ -39,7 +39,7 @@ class ConjKernel : public framework::OpKernel<T> {
    auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
    // call new kernel
-   pten::Conj<T>(dev_ctx, *pt_x.get(), pt_out.get());
+   pten::ConjKernel<T>(dev_ctx, *pt_x.get(), pt_out.get());
  }
};
......
...@@ -117,6 +117,13 @@ class DotGradOp : public framework::OperatorWithKernel {
                      ctx, framework::GradVarName("Out")),
        ctx.GetPlace());
  }
+ framework::KernelSignature GetExpectedPtenKernelArgs(
+     const framework::ExecutionContext& ctx) const override {
+   return framework::KernelSignature(
+       "dot_grad", {"X", "Y", framework::GradVarName("Out")}, {},
+       {framework::GradVarName("X"), framework::GradVarName("Y")});
+ }
};
template <typename T>
......
...@@ -22,217 +22,14 @@
// only can include the headers in paddle/pten/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
-#include "paddle/pten/include/linalg.h"
+#include "paddle/pten/kernels/dot_grad_kernel.h"
+#include "paddle/pten/kernels/dot_kernel.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, typename R>
struct P {
void operator()(T a, R b);
};
template <typename DeviceContext, typename T, typename Enabel = void>
struct DotGradFunction {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx);
};
template <typename DeviceContext, typename T>
struct DotGradFunction<DeviceContext, T, math::EnableComplex<T>> {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_y->numel());
math::ConjFunctor<T> functor(tensor_y->data<T>(), tensor_y->numel(),
tensor_dx->data<T>());
for_range(functor);
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
dx.device(dev) = dx * dout.broadcast(size);
}
if (tensor_dy) {
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_y->numel());
math::ConjFunctor<T> functor(tensor_x->data<T>(), tensor_x->numel(),
tensor_dy->data<T>());
for_range(functor);
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
dy.device(dev) = dy * dout.broadcast(size);
}
} else {
auto dout = framework::EigenMatrix<T>::From(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = framework::EigenMatrix<T>::From(*tensor_y);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_y->numel());
math::ConjFunctor<T> functor(tensor_y->data<T>(), tensor_y->numel(),
tensor_dx->data<T>());
for_range(functor);
auto dx = framework::EigenMatrix<T>::From(*tensor_dx);
dx.device(dev) = dx * dout.broadcast(size);
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = framework::EigenMatrix<T>::From(*tensor_x);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_x->numel());
math::ConjFunctor<T> functor(tensor_x->data<T>(), tensor_x->numel(),
tensor_dy->data<T>());
for_range(functor);
auto dy = framework::EigenMatrix<T>::From(*tensor_dy);
dy.device(dev) = dy * dout.broadcast(size);
}
}
#else
const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_y = tensor_y->data<T>();
const framework::DDim& dim = tensor_x->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = T(data_y[i].real, -data_y[i].imag) * data_dout[s];
}
}
if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_x = tensor_x->data<T>();
const framework::DDim& dim = tensor_y->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = T(data_x[i].real, -data_x[i].imag) * data_dout[s];
}
}
#endif
}
};
template <typename DeviceContext, typename T>
struct DotGradFunction<DeviceContext, T, math::DisableComplex<T>> {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dy) {
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
dy.device(dev) = x * dout.broadcast(size);
}
} else {
auto dout = framework::EigenMatrix<T>::From(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = framework::EigenMatrix<T>::From(*tensor_y);
auto dx = framework::EigenMatrix<T>::From(*tensor_dx);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = framework::EigenMatrix<T>::From(*tensor_x);
auto dy = framework::EigenMatrix<T>::From(*tensor_dy);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
dy.device(dev) = x * dout.broadcast(size);
}
}
#else
auto const *x = tensor_x->data<T>(), *y = tensor_y->data<T>(),
*dz = tensor_dout->data<T>();
auto&& d = tensor_x->dims();
auto const N = tensor_x->numel();
auto const B = d[d.size() - 1];
if (tensor_dx) {
auto* dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
for (auto j = 0; j < N / B; ++j) {
auto const ss = dz[j];
for (auto i = 0; i < B; ++i) *dx++ = *y++ * ss;
}
}
if (tensor_dy) {
auto* dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
for (auto j = 0; j < N / B; ++j) {
auto const ss = dz[j];
for (auto i = 0; i < B; i++) *dy++ = *x++ * ss;
}
}
#endif
}
};
// See Note [ Why still keep the original kernel implementation? ]
template <typename DeviceContext, typename T>
class DotKernel : public framework::OpKernel<T> {
...@@ -249,7 +46,7 @@ class DotKernel : public framework::OpKernel<T> {
    auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
    // call new kernel
-   pten::Dot<T>(dev_ctx, *pt_x.get(), *pt_y.get(), pt_out.get());
+   pten::DotKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), pt_out.get());
  }
};
...@@ -266,8 +63,17 @@ class DotGradKernel : public framework::OpKernel<T> {
    if (tensor_dx) tensor_dx->mutable_data<T>(ctx.GetPlace());
    if (tensor_dy) tensor_dy->mutable_data<T>(ctx.GetPlace());
-   DotGradFunction<DeviceContext, T>()(tensor_x, tensor_y, tensor_dout,
-                                       tensor_dx, tensor_dy, ctx);
+   auto pt_x = paddle::experimental::MakePtenDenseTensor(*tensor_x);
+   auto pt_y = paddle::experimental::MakePtenDenseTensor(*tensor_y);
+   auto pt_dout = paddle::experimental::MakePtenDenseTensor(*tensor_dout);
+   auto pt_dx = paddle::experimental::MakePtenDenseTensor(*tensor_dx);
+   auto pt_dy = paddle::experimental::MakePtenDenseTensor(*tensor_dy);
+   auto& dev_ctx = ctx.device_context<DeviceContext>();
+   // call new kernel
+   pten::DotGradKernel<T>(dev_ctx, *pt_x, *pt_y, *pt_dout, pt_dx.get(),
+                          pt_dy.get());
  }
};
......
...@@ -225,6 +225,10 @@ class Blas {
              const framework::Tensor& mat_b, const MatDescriptor& dim_b,
              T alpha, framework::Tensor* mat_out, T beta) const;
+ template <typename T>
+ void MatMul(const T* mat_a, const MatDescriptor& dim_a, const T* mat_b,
+             const MatDescriptor& dim_b, T alpha, T* mat_out, T beta) const;
  template <typename T>
  void VINV(int n, const T* a, T* y) const;
......
...@@ -1249,6 +1249,15 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
                                 const framework::Tensor &mat_b,
                                 const MatDescriptor &dim_b, T alpha,
                                 framework::Tensor *mat_out, T beta) const {
+ MatMul(mat_a.data<T>(), dim_a, mat_b.data<T>(), dim_b, alpha,
+        mat_out->data<T>(), beta);
+}
+template <typename DeviceContext>
+template <typename T>
+void Blas<DeviceContext>::MatMul(const T *mat_a, const MatDescriptor &dim_a,
+                                 const T *mat_b, const MatDescriptor &dim_b,
+                                 T alpha, T *mat_out, T beta) const {
  PADDLE_ENFORCE_EQ(
      dim_a.width_, dim_b.height_,
      platform::errors::InvalidArgument(
...@@ -1261,8 +1270,7 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
  CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
  if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
    this->template GEMM<T>(transA, transB, dim_a.height_, dim_b.width_,
-                          dim_a.width_, alpha, mat_a.data<T>(),
-                          mat_b.data<T>(), beta, mat_out->data<T>());
+                          dim_a.width_, alpha, mat_a, mat_b, beta, mat_out);
  } else {
    PADDLE_ENFORCE_EQ(
        dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 ||
...@@ -1273,8 +1281,8 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
            "But got dim_a.batch_size = %d, dim_b.batch_size = %d.",
            dim_a.batch_size_, dim_b.batch_size_));
    this->template BatchedGEMM<T>(
-       transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
-       mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
+       transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha, mat_a,
+       mat_b, beta, mat_out,
        dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_,
        dim_a.stride_, dim_b.stride_);
  }
......
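The raw-pointer MatMul overload added above lets code that holds pten::DenseTensor buffers reuse the same GEMM/BatchedGEMM dispatch without first wrapping the data in a framework::Tensor; the existing Tensor-based overload now simply forwards to it. A minimal illustrative call (hypothetical variable names, assuming dim_a/dim_b are MatDescriptor values describing the two operands):

  // x, y, out are pten::DenseTensor; alpha = 1, beta = 0.
  blas.MatMul(x.data<T>(), dim_a, y.data<T>(), dim_b,
              static_cast<T>(1), out->mutable_data<T>(), static_cast<T>(0));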
...@@ -389,6 +389,14 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
                      tensor.place(), tensor.layout());
    }
  }
+ framework::KernelSignature GetExpectedPtenKernelArgs(
+     const framework::ExecutionContext& ctx) const override {
+   return framework::KernelSignature(
+       "matmul_grad", {"X", "Y", framework::GradVarName("Out")},
+       {"trans_x", "trans_y"},
+       {framework::GradVarName("X"), framework::GradVarName("Y")});
+ }
};
template <typename T>
...@@ -431,6 +439,13 @@ class MatMulV2OpDoubleGrad : public framework::OperatorWithKernel {
      context->ShareDim("DOut", "DDOut");
    }
  }
+ framework::KernelSignature GetExpectedPtenKernelArgs(
+     const framework::ExecutionContext& ctx) const override {
+   return framework::KernelSignature(
+       "matmul_double_grad", {"X", "Y", "DOut", "DDX", "DDY"},
+       {"trans_x", "trans_y"}, {"DX", "DY", "DDOut"});
+ }
};
template <typename T>
...@@ -500,6 +515,15 @@ class MatMulV2OpTripleGrad : public framework::OperatorWithKernel {
      context->ShareDim("Y", "D_DDY_out");
    }
  }
+ framework::KernelSignature GetExpectedPtenKernelArgs(
+     const framework::ExecutionContext& ctx) const override {
+   return framework::KernelSignature(
+       "matmul_triple_grad",
+       {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
+       {"trans_x", "trans_y"},
+       {"D_X_out", "D_Y_out", "D_DOut_out", "D_DDX_out", "D_DDY_out"});
+ }
};
template <typename T>
......
...@@ -70,6 +70,12 @@ DenseTensor& DenseTensor::operator=(const DenseTensor& other) {
  return *this;
}
+DenseTensor& DenseTensor::operator=(DenseTensor&& other) {
+  meta_ = std::move(other.meta_);
+  storage_.swap(other.storage_);
+  return *this;
+}
int64_t DenseTensor::numel() const {
  if (meta_.is_scalar) {
    return 1;
......
...@@ -97,6 +97,8 @@ class DenseTensor : public TensorBase,
  /// \brief DenseTensor shallow copy assignment.
  DenseTensor& operator=(const DenseTensor& other);
+ DenseTensor& operator=(DenseTensor&& other);
  /// \brief Destroy the tensor object and release exclusive resources.
  virtual ~DenseTensor() = default;
......
...@@ -29,6 +29,9 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
    {"flatten_contiguous_range", "flatten"},
    {"flatten_contiguous_range_grad", "flatten_grad"},
    {"matmul_v2", "matmul"},
+   {"matmul_v2_grad", "matmul_grad"},
+   {"matmul_v2_grad_grad", "matmul_double_grad"},
+   {"matmul_v2_triple_grad", "matmul_triple_grad"},
    {"reduce_mean", "mean"},
    {"reduce_sum", "sum"},
    {"reshape2", "reshape"},
...@@ -36,6 +39,8 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
    {"flatten", "deprecated"},
    {"flatten_grad", "deprecated"},
    {"matmul", "deprecated"},
+   {"matmul_grad", "deprecated"},
+   {"matmul_grad_grad", "deprecated"},
    {"mean", "deprecated"},
    {"reshape", "deprecated"},
    {"sum", "deprecated"}};
......
...@@ -50,6 +50,11 @@ void KernelContext::EmplaceBackOutputWithoutSetRange(
  outputs_.emplace_back(std::move(output));
}
+void KernelContext::SetOutputWithoutSetRange(
+    int index, std::shared_ptr<TensorBase> output) {
+  outputs_.at(index) = std::move(output);
+}
void KernelContext::EmplaceBackOutputs(
    paddle::SmallVector<std::shared_ptr<TensorBase>> outputs) {
  int index = outputs_.size();
...@@ -119,8 +124,10 @@ void KernelContext::ClearData() {
    }
  }
  for (auto& out : outputs_) {
-   CompatibleDenseTensorUtils::ClearStorage(
-       static_cast<DenseTensor*>(out.get()));
+   if (out) {
+     CompatibleDenseTensorUtils::ClearStorage(
+         static_cast<DenseTensor*>(out.get()));
+   }
  }
  attrs_.clear();
}
......
...@@ -62,6 +62,8 @@ class KernelContext {
  void EmplaceBackOutputWithoutSetRange(std::shared_ptr<TensorBase> output);
+ void SetOutputWithoutSetRange(int index, std::shared_ptr<TensorBase> output);
  void EmplaceBackOutputs(
      paddle::SmallVector<std::shared_ptr<TensorBase>> outputs);
...@@ -80,6 +82,14 @@ class KernelContext {
    return static_cast<const TensorType&>(*(inputs_.at(idx)));
  }
+ template <typename TensorType>
+ paddle::optional<const TensorType&> OptionalInputAt(size_t idx) const {
+   const auto& input = inputs_.at(idx);
+   return input ? paddle::optional<const TensorType&>{static_cast<
+                      const TensorType&>(*input)}
+                : paddle::optional<const TensorType&>{paddle::none};
+ }
  std::shared_ptr<TensorBase>& MutableInputPtrAt(size_t idx) {
    return inputs_.at(idx);
  }
......
...@@ -65,6 +65,10 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
    } else if (arg_type == std::type_index(typeid(const DenseTensor&))) {
      args_def->AppendInput(
          default_key.backend(), default_tensor_layout, default_key.dtype());
+   } else if (arg_type == std::type_index(typeid(
+                              paddle::optional<const DenseTensor&>))) {
+     args_def->AppendInput(
+         default_key.backend(), default_tensor_layout, default_key.dtype());
    } else if (arg_type ==
               std::type_index(typeid(const std::vector<DenseTensor>&))) {
      args_def->AppendInput(
......
...@@ -77,6 +77,27 @@ namespace pten {
    }                                                                 \
  }
#define PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(tensor_type) \
template <typename... Tail> \
struct KernelCallHelper<paddle::optional<const tensor_type&>, Tail...> { \
template <int dev_ctx_idx, \
int in_idx, \
int attr_idx, \
int out_idx, \
typename... PreviousArgs> \
static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \
static_assert(attr_idx == 0, \
"Kernel's Input should appear before Attributes."); \
static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \
const std::pair<int, int> range = ctx->InputRangeAt(in_idx); \
auto arg = ctx->OptionalInputAt<tensor_type>(range.first); \
KernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \
ctx, pargs..., arg); \
} \
}
#define PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(tensor_type)    \
  template <typename... Tail>                                          \
  struct KernelCallHelper<const std::vector<tensor_type>&, Tail...> {  \
...@@ -190,6 +211,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
  /* Input Helpers */
  PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor);
+ PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor);
  PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
  // TODO(chenweihang): adapt SelectedRows
  // PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRowsTensor);
......
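A minimal sketch (hypothetical kernel, not part of this change) of how a pten kernel can declare and branch on an optional input once the helper above is instantiated for DenseTensor; when the corresponding input slot holds nullptr, OptionalInputAt hands the kernel paddle::none:

  template <typename T, typename Context>
  void AddBiasKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     paddle::optional<const DenseTensor&> bias,  // may be absent
                     DenseTensor* out) {
    out->mutable_data<T>();
    if (bias) {
      // bias is present: *bias is a const DenseTensor& and can be used with x.
    } else {
      // bias was not provided: fall back to the bias-free path.
    }
  }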
...@@ -30,7 +30,7 @@ DenseTensor Dot(const ContextT& dev_ctx,
      pten::make_intrusive<paddle::experimental::SharedStorage>(
          dev_ctx.GetPlace()),
      std::move(out_meta));
- Dot<T, ContextT>(dev_ctx, x, y, &dense_out);
+ DotKernel<T, ContextT>(dev_ctx, x, y, &dense_out);
  return dense_out;
}
......
...@@ -48,15 +48,4 @@ DenseTensor Scale(const ContextT& dev_ctx,
  return dense_out;
}
template <typename T, typename ContextT>
DenseTensor Conj(const ContextT& dev_ctx, const DenseTensor& x) {
auto out_meta = UnchangedInferMeta(x.meta());
pten::DenseTensor dense_out(
pten::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()),
std::move(out_meta));
Conj<T>(dev_ctx, x, &dense_out);
return dense_out;
}
} // namespace pten
...@@ -16,9 +16,20 @@ limitations under the License. */
#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/infermeta/unary.h"
+#include "paddle/pten/kernels/empty_kernel.h"
namespace pten {
template <typename T, typename Context>
-void Conj(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
+void ConjKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
+template <typename T, typename Context>
+DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) {
+  auto out_meta = UnchangedInferMeta(x.meta());
+  auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
+  ConjKernel<T>(dev_ctx, x, &dense_out);
+  return dense_out;
+}
} // namespace pten
...@@ -24,7 +24,7 @@
PT_REGISTER_CTX_KERNEL(conj,
                       CPU,
                       ALL_LAYOUT,
-                      pten::Conj,
+                      pten::ConjKernel,
                       paddle::platform::complex<float>,
                       paddle::platform::complex<double>,
                       float,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/dot_grad_kernel.h"
#include "paddle/pten/kernels/impl/dot_grad_kernel_impl.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_CTX_KERNEL(dot_grad,
CPU,
ALL_LAYOUT,
pten::DotGradKernel,
float,
double,
int,
int64_t,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
...@@ -23,10 +23,10 @@
namespace pten {
template <typename T, typename Context>
-void Dot(const Context& dev_ctx,
-         const DenseTensor& x,
-         const DenseTensor& y,
-         DenseTensor* out) {
+void DotKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               const DenseTensor& y,
+               DenseTensor* out) {
  auto const *x_ptr = x.data<T>(), *x_ptr_ = &x_ptr[0];
  auto const *y_ptr = y.data<T>(), *y_ptr_ = &y_ptr[0];
  auto* z = out->mutable_data<T>();
...@@ -52,7 +52,7 @@ using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_CTX_KERNEL(dot,
                       CPU,
                       ALL_LAYOUT,
-                      pten::Dot,
+                      pten::DotKernel,
                       float,
                       double,
                       int,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/kernels/matmul_grad_kernel.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_CTX_KERNEL(matmul_grad,
CPU,
ALL_LAYOUT,
pten::MatmulGradKernel,
float,
double,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_double_grad,
CPU,
ALL_LAYOUT,
pten::MatmulDoubleGradKernel,
float,
double,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_triple_grad,
CPU,
ALL_LAYOUT,
pten::MatmulTripleGradKernel,
float,
double,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
template <typename T, typename Context>
void DotGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy);
template <typename T, typename Context>
void DotDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& ddx,
const DenseTensor& ddy,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy,
DenseTensor* ddout);
template <typename T, typename Context>
void DotTripleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& ddx,
const DenseTensor& ddy,
const DenseTensor& d_dx,
const DenseTensor& d_dy,
const DenseTensor& dout,
const DenseTensor& d_ddout,
DenseTensor* d_x,
DenseTensor* d_y,
DenseTensor* d_ddx,
DenseTensor* d_ddy,
DenseTensor* d_dout);
} // namespace pten
...@@ -19,9 +19,9 @@
namespace pten {
template <typename T, typename Context>
-void Dot(const Context& dev_ctx,
-         const DenseTensor& x,
-         const DenseTensor& y,
-         DenseTensor* out);
+void DotKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               const DenseTensor& y,
+               DenseTensor* out);
} // namespace pten
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/kernels/empty_kernel.h"
#include "paddle/pten/backends/all_context.h"
#include "paddle/pten/core/kernel_registry.h"
+#include "paddle/fluid/platform/complex.h"
namespace pten {
-template <typename T, typename ContextT>
-void EmptyKernel(const ContextT& dev_ctx,
+template <typename T, typename Context>
+void EmptyKernel(const Context& dev_ctx,
                 const ScalarArray& shape,
                 DenseTensor* out) {
  out->Resize(paddle::framework::make_ddim(shape.GetData()));
}
-template <typename T, typename ContextT>
-void EmptyLikeKernel(const ContextT& dev_ctx, DenseTensor* out) {
+template <typename T, typename Context>
+void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) {
  out->mutable_data<T>();
}
...@@ -37,44 +38,62 @@ PT_REGISTER_CTX_KERNEL(empty,
                       CPU,
                       ALL_LAYOUT,
                       pten::EmptyKernel,
-                      bool,
-                      int,
-                      int64_t,
                       float,
                       double,
-                      paddle::platform::float16) {}
+                      uint8_t,
+                      int16_t,
+                      int,
+                      int64_t,
+                      bool,
+                      paddle::platform::float16,
+                      paddle::platform::bfloat16,
+                      paddle::platform::complex<float>,
+                      paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(empty_like,
                       CPU,
                       ALL_LAYOUT,
                       pten::EmptyLikeKernel,
-                      bool,
-                      int,
-                      int64_t,
                       float,
                       double,
-                      paddle::platform::float16) {}
+                      uint8_t,
+                      int16_t,
+                      int,
+                      int64_t,
+                      bool,
+                      paddle::platform::float16,
+                      paddle::platform::bfloat16,
+                      paddle::platform::complex<float>,
+                      paddle::platform::complex<double>) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_CTX_KERNEL(empty,
                       GPU,
                       ALL_LAYOUT,
                       pten::EmptyKernel,
-                      bool,
-                      int,
-                      int64_t,
                       float,
                       double,
-                      paddle::platform::float16) {}
+                      uint8_t,
+                      int16_t,
+                      int,
+                      int64_t,
+                      bool,
+                      paddle::platform::float16,
+                      paddle::platform::complex<float>,
+                      paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(empty_like,
                       GPU,
                       ALL_LAYOUT,
                       pten::EmptyLikeKernel,
-                      bool,
-                      int,
-                      int64_t,
                       float,
                       double,
-                      paddle::platform::float16) {}
+                      uint8_t,
+                      int16_t,
+                      int,
+                      int64_t,
+                      bool,
+                      paddle::platform::float16,
+                      paddle::platform::complex<float>,
+                      paddle::platform::complex<double>) {}
#endif
...@@ -41,6 +41,14 @@ DenseTensor Empty(const Context& dev_ctx, DenseTensorMeta&& meta) {
  return dense_out;
}
+template <typename T, typename Context>
+DenseTensor Empty(const Context& dev_ctx) {
+  return Empty<T, Context>(dev_ctx,
+                           {paddle::experimental::CppTypeToDataType<T>::Type(),
+                            {-1},
+                            DataLayout::NCHW});
+}
template <typename T, typename Context>
DenseTensor Empty(const Context& dev_ctx,
                  const ScalarArray& shape,
......
...@@ -24,7 +24,8 @@
PT_REGISTER_CTX_KERNEL(conj,
                       GPU,
                       ALL_LAYOUT,
-                      pten::Conj,
+                      pten::ConjKernel,
+                      paddle::platform::float16,
                       paddle::platform::complex<float>,
                       paddle::platform::complex<double>,
                       float,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/kernels/dot_grad_kernel.h"
#include "paddle/pten/kernels/impl/dot_grad_kernel_impl.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_CTX_KERNEL(dot_grad,
GPU,
ALL_LAYOUT,
pten::DotGradKernel,
float,
double,
int,
int64_t,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
...@@ -25,10 +25,10 @@
namespace pten {
template <typename T, typename Context>
-void Dot(const Context& dev_ctx,
-         const DenseTensor& x,
-         const DenseTensor& y,
-         DenseTensor* out) {
+void DotKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               const DenseTensor& y,
+               DenseTensor* out) {
  out->mutable_data<T>();
  if (1 == out->dims().size()) {
    auto eigen_out = pten::EigenScalar<T>::From(*out);
...@@ -55,7 +55,7 @@ using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_CTX_KERNEL(dot,
                       GPU,
                       ALL_LAYOUT,
-                      pten::Dot,
+                      pten::DotKernel,
                       float,
                       double,
                       int,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/kernels/matmul_grad_kernel.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_CTX_KERNEL(matmul_grad,
GPU,
ALL_LAYOUT,
pten::MatmulGradKernel,
float,
double,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_double_grad,
GPU,
ALL_LAYOUT,
pten::MatmulDoubleGradKernel,
float,
double,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_triple_grad,
GPU,
ALL_LAYOUT,
pten::MatmulTripleGradKernel,
float,
double,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
...@@ -17,6 +17,9 @@
#include "paddle/fluid/framework/ddim.h"
#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/fluid/operators/eigen/eigen_function.h"
+#include "paddle/pten/kernels/hybird/eigen/common.h"
namespace pten {
namespace math {
...@@ -30,5 +33,30 @@ struct TransposeNormal {
                  const std::vector<int64_t>& axis);
};
template <typename DeviceContext, typename T, int Rank>
struct Transpose {
void operator()(const DeviceContext& dev_ctx,
const DenseTensor& in,
DenseTensor* out,
const std::vector<int>& axis) {
Eigen::array<int, Rank> permute;
for (int i = 0; i < Rank; i++) {
permute[i] = axis[i];
}
auto eigen_in = pten::EigenTensor<T, Rank>::From(in);
auto eigen_out = pten::EigenTensor<T, Rank>::From(*out);
auto* dev = dev_ctx.eigen_device();
// use 32bit index to speed up computation
bool use_32bit_index = eigen_out.size() < Eigen::NumTraits<int>::highest();
bool is_gpu_place = paddle::platform::is_gpu_place(dev_ctx.GetPlace());
if (use_32bit_index && is_gpu_place) {
To32BitIndex(eigen_out).device(*dev) =
To32BitIndex(eigen_in).shuffle(permute);
} else {
eigen_out.device(*dev) = eigen_in.shuffle(permute);
}
}
};
} // namespace math
} // namespace pten
...@@ -21,12 +21,14 @@
namespace pten {
template <typename T, typename Context>
-void Conj(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) {
+void ConjKernel(const Context& context,
+                const DenseTensor& x,
+                DenseTensor* out) {
  auto numel = x.numel();
  auto* x_data = x.data<T>();
  auto* out_data = out->mutable_data<T>();
- paddle::platform::ForRange<Context> for_range(dev_ctx, numel);
+ paddle::platform::ForRange<Context> for_range(context, numel);
  paddle::operators::math::ConjFunctor<T> functor(x_data, numel, out_data);
  for_range(functor);
}
......
This diff is collapsed.
This diff is collapsed.
...@@ -86,7 +86,7 @@ static void IndexIncreaseFromDims(const int ndim,
}
template <typename Context, typename T>
-void MatMulFunction(const Context& context,
+void MatMulFunction(const Context& dev_ctx,
                    const DenseTensor& X,
                    const DenseTensor& Y,
                    const std::vector<std::int64_t>& x_dims,
...@@ -102,7 +102,7 @@ void MatMulFunction(const Context& context,
  const T* x_data = X.data<T>();
  const T* y_data = Y.data<T>();
- auto blas = paddle::operators::math::GetBlas<Context, T>(context);
+ auto blas = paddle::operators::math::GetBlas<Context, T>(dev_ctx);
  if (x_ndim == 1 && y_ndim == 1) {
    const int M = X.numel();
...@@ -117,6 +117,8 @@ void MatMulFunction(const Context& context,
                          M,
                          N));
    VLOG(3) << "MatMul's case 1";
+   Out->Resize({1});
+   Out->mutable_data<T>();
    blas.GEMM(CblasNoTrans,
              CblasTrans,
              1,
...@@ -471,7 +473,7 @@ void MatMulFunction(const Context& context,
}
template <typename Context, typename T>
-void MatMulFunction(const Context& context,
+void MatMulFunction(const Context& dev_ctx,
                    const DenseTensor& X,
                    const DenseTensor& Y,
                    DenseTensor* Out,
...@@ -481,11 +483,11 @@ void MatMulFunction(const Context& context,
  const std::vector<std::int64_t> x_dims = vectorize(X.dims());
  const std::vector<std::int64_t> y_dims = vectorize(Y.dims());
  MatMulFunction<Context, T>(
-     context, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag);
+     dev_ctx, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag);
}
template <typename T, typename Context>
-void MatmulKernel(const Context& context,
+void MatmulKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  const DenseTensor& y,
                  bool transpose_x,
...@@ -501,7 +503,7 @@ void MatmulKernel(const Context& context,
      paddle::platform::errors::InvalidArgument(
          "The Input(Y) dims size must not be equal 0,"
          " but reviced dims size is 0. "));
- MatMulFunction<Context, T>(context, x, y, out, transpose_x, transpose_y);
+ MatMulFunction<Context, T>(dev_ctx, x, y, out, transpose_x, transpose_y);
}
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace pten {
template <typename T, typename Context>
void MatmulGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
bool transpose_x,
bool transpose_y,
DenseTensor* dx,
DenseTensor* dy);
template <typename T, typename Context>
void MatmulDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
paddle::optional<const DenseTensor&> ddx,
paddle::optional<const DenseTensor&> ddy,
bool transpose_x,
bool transpose_y,
DenseTensor* dx,
DenseTensor* dy,
DenseTensor* ddout);
template <typename T, typename Context>
void MatmulTripleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
const DenseTensor& ddx,
const DenseTensor& ddy,
paddle::optional<const DenseTensor&> d_dx,
paddle::optional<const DenseTensor&> d_dy,
paddle::optional<const DenseTensor&> d_ddout,
bool transpose_x,
bool transpose_y,
DenseTensor* out_d_x,
DenseTensor* out_d_y,
DenseTensor* out_d_dout,
DenseTensor* out_d_ddx,
DenseTensor* out_d_ddy);
} // namespace pten
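As a reminder of what these kernels compute (the standard matmul backward, stated here for orientation rather than taken from this diff): with trans_x = trans_y = false,

    dX = dOut * Y^T,    dY = X^T * dOut,

with operands transposed or swapped accordingly when trans_x / trans_y are set. The double- and triple-grad variants chain the same rule, which is why several of their inputs (ddx, ddy, d_dx, d_dy, d_ddout) may legitimately be absent and are declared as paddle::optional.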
...@@ -14,14 +14,15 @@
#pragma once
-#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/infermeta/binary.h"
+#include "paddle/pten/kernels/empty_kernel.h"
namespace pten {
template <typename T, typename Context>
-void MatmulKernel(const Context& context,
+void MatmulKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  const DenseTensor& y,
                  bool transpose_x,
...@@ -29,17 +30,14 @@ void MatmulKernel(const Context& context,
                  DenseTensor* out);
template <typename T, typename Context>
-DenseTensor Matmul(const Context& context,
+DenseTensor Matmul(const Context& dev_ctx,
                   const DenseTensor& x,
                   const DenseTensor& y,
                   bool transpose_x,
                   bool transpose_y) {
  auto out_meta = MatmulInferMeta(x.meta(), y.meta(), transpose_x, transpose_y);
- DenseTensor dense_out(
-     pten::make_intrusive<paddle::experimental::SharedStorage>(
-         context.GetPlace()),
-     std::move(out_meta));
- MatmulKernel<T, Context>(context, x, y, transpose_x, transpose_y, &dense_out);
+ auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
+ MatmulKernel<T, Context>(dev_ctx, x, y, transpose_x, transpose_y, &dense_out);
  return dense_out;
}
......