未验证 提交 be817719 编写于 作者: Z zyfncg 提交者: GitHub

【PTen】Add dot and matmul grad kernel in pten (#38713)

* refactor matmul directory in pten

* fix merge conflict

* add dot_grad kernel

* add dot_grad kernel in pten

* add matmul_grad kernel

* update the code

* delete useless code in fluid

* fix some bug of running matmul grad kernel

* fix merge conflict

* refactor some code

* refactor code
上级 5b940c44
......@@ -79,6 +79,9 @@ function(kernel_library TARGET)
endif()
list(APPEND all_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.h)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/impl/${TARGET}_impl.h)
list(APPEND all_srcs ${CMAKE_CURRENT_SOURCE_DIR}/impl/${TARGET}_impl.h)
endif()
list(APPEND all_srcs ${common_srcs})
list(APPEND all_srcs ${cpu_srcs})
list(APPEND all_srcs ${gpu_srcs})
......
......@@ -1880,16 +1880,32 @@ void OperatorWithKernel::BuildPtenKernelContext(
// Otherwise,we will create new storage.
for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
if (current_vector_size > start_idx + offset) {
experimental::ReMakePtenDenseTensorFromVar(
outs_vector[offset], out_def,
auto* buffer_tensor =
pt_kernel_context_->MutableOutputAt<pten::DenseTensor>(start_idx +
offset));
offset);
if (buffer_tensor) {
experimental::ReMakePtenDenseTensorFromVar(outs_vector[offset],
out_def, buffer_tensor);
}
} else {
pt_kernel_context_->EmplaceBackOutputWithoutSetRange(
experimental::MakePtenTensorBaseFromVar(outs_vector[offset],
out_def));
}
}
// Deal with the case that some outputs are NULL when run the kernel.
// For example : the outputs of matmul_grad are dx and dy,
// sometimes dx or dy may be NULL.
if (outs_vector.empty()) {
if (current_vector_size > start_idx) {
pt_kernel_context_->SetOutputWithoutSetRange(start_idx, {nullptr});
} else {
pt_kernel_context_->EmplaceBackOutputWithoutSetRange({nullptr});
}
end_idx = start_idx + 1;
}
pt_kernel_context_->AssignOutputRange(std::make_pair(start_idx, end_idx),
i);
}
......@@ -2002,9 +2018,11 @@ void OperatorWithKernel::WriteBackToOutputs(RuntimeContext* ctx) const {
range_pair.first, range_pair.second);
for (size_t j = 0; j < pten_outs.size(); ++j) {
if (pten_outs[j]) {
experimental::MakeVariableFromPtenTensor(pten_outs[j], outs_vector[j]);
}
}
}
}
} // namespace framework
......
......@@ -99,7 +99,7 @@ KernelSignatureMap& KernelSignatureMap::Instance() {
const auto& op_type = pair.first;
const auto* op_proto = pair.second.proto_;
if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) &&
op_proto != nullptr) {
op_proto) {
KernelArgsNameMakerByOpProto maker(op_proto);
VLOG(10) << "Register kernel signature for " << op_type;
auto success = kernel_signature_map_->map_
......
......@@ -338,19 +338,41 @@ static void BuildDygraphPtenKernelContext(
for (size_t i = 0; i < output_names.size(); ++i) {
auto& out_def = output_defs.at(i);
auto& outs_vector = outs.at(output_names[i]);
size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second);
size_t end_idx = start_idx + outs_vector.size();
auto current_vector_size = kernel_ctx->OutputsSize();
auto iter = outs.find(output_names[i]);
if (iter == outs.end()) {
if (current_vector_size > start_idx) {
kernel_ctx->SetOutputWithoutSetRange(start_idx, {nullptr});
} else {
kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
}
kernel_ctx->AssignOutputRange(std::make_pair(start_idx, start_idx + 1),
i);
continue;
}
auto& outs_vector = iter->second;
size_t end_idx = start_idx + outs_vector.size();
// If the memory needed is less than the current memory allocated, we will
// reuse the current memory by using ReMakePtenDenseTensorFromVar.
// Otherwise,we will create new storage.
for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
if (current_vector_size > start_idx + offset) {
auto* buffer_tensor =
kernel_ctx->MutableOutputAt<pten::DenseTensor>(start_idx + offset);
if (buffer_tensor) {
experimental::ReMakePtenDenseTensorFromVar(
outs_vector[offset]->MutableVar(), out_def,
kernel_ctx->MutableOutputAt<pten::DenseTensor>(start_idx + offset));
outs_vector[offset]->MutableVar(), out_def, buffer_tensor);
} else {
kernel_ctx->SetOutputWithoutSetRange(
start_idx + offset,
experimental::MakePtenTensorBaseFromVar(
outs_vector[offset]->MutableVar(), out_def));
}
} else {
kernel_ctx->EmplaceBackOutputWithoutSetRange(
experimental::MakePtenTensorBaseFromVar(
......@@ -465,7 +487,9 @@ static void WriteBackToOutputs(
auto& output_names = std::get<2>(pt_kernel_signature.args);
for (size_t i = 0; i < output_names.size(); ++i) {
auto& outs_vector = outs.at(output_names[i]);
auto iter = outs.find(output_names[i]);
if (iter != outs.end()) {
auto& outs_vector = iter->second;
auto& range_pair = kernel_ctx->OutputRangeAt(i);
auto pten_outs = kernel_ctx->MutableOutputBetween<pten::DenseTensor>(
......@@ -476,6 +500,7 @@ static void WriteBackToOutputs(
outs_vector[j]->MutableVar());
}
}
}
}
template <typename VarType>
......@@ -529,6 +554,7 @@ static void PreparedOpRunImpl(
template <typename VarType>
static void PreparedOpRunPtImpl(
const framework::OperatorBase& op,
const framework::OpKernelType& kernel_type,
const framework::KernelSignature& pt_kernel_signature,
const pten::Kernel& pt_kernel, pten::KernelContext* pt_kernel_context,
platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
......@@ -558,7 +584,9 @@ static void PreparedOpRunPtImpl(
pt_kernel_context->ClearData();
// TODO(chenweihang): add debug flags later
// TODO(chenweihang): deal with complex cases later
if (framework::IsComplexType(kernel_type.data_type_)) {
HandleComplexGradToRealGrad<VarType>(outs);
}
}
void PreparedOp::Run(const NameVarMap<VarBase>& ins,
......@@ -566,9 +594,9 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
if (run_pten_kernel_) {
PreparedOpRunPtImpl<VarBase>(op_, pt_kernel_signature_, pt_kernel_,
pt_kernel_context_, dev_ctx_, ins, outs, attrs,
default_attrs);
PreparedOpRunPtImpl<VarBase>(op_, kernel_type_, pt_kernel_signature_,
pt_kernel_, pt_kernel_context_, dev_ctx_, ins,
outs, attrs, default_attrs);
} else {
PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, dev_ctx_, ins,
outs, attrs, default_attrs);
......@@ -580,9 +608,9 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
if (run_pten_kernel_) {
PreparedOpRunPtImpl<VariableWrapper>(op_, pt_kernel_signature_, pt_kernel_,
pt_kernel_context_, dev_ctx_, ins,
outs, attrs, default_attrs);
PreparedOpRunPtImpl<VariableWrapper>(
op_, kernel_type_, pt_kernel_signature_, pt_kernel_, pt_kernel_context_,
dev_ctx_, ins, outs, attrs, default_attrs);
} else {
PreparedOpRunImpl<VariableWrapper>(op_, ctx_, kernel_type_, func_, dev_ctx_,
ins, outs, attrs, default_attrs);
......
......@@ -39,7 +39,7 @@ class ConjKernel : public framework::OpKernel<T> {
auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
// call new kernel
pten::Conj<T>(dev_ctx, *pt_x.get(), pt_out.get());
pten::ConjKernel<T>(dev_ctx, *pt_x.get(), pt_out.get());
}
};
......
......@@ -117,6 +117,13 @@ class DotGradOp : public framework::OperatorWithKernel {
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
return framework::KernelSignature(
"dot_grad", {"X", "Y", framework::GradVarName("Out")}, {},
{framework::GradVarName("X"), framework::GradVarName("Y")});
}
};
template <typename T>
......
......@@ -22,217 +22,14 @@
// only can include the headers in paddle/pten/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/linalg.h"
#include "paddle/pten/kernels/dot_grad_kernel.h"
#include "paddle/pten/kernels/dot_kernel.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, typename R>
struct P {
void operator()(T a, R b);
};
template <typename DeviceContext, typename T, typename Enabel = void>
struct DotGradFunction {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx);
};
template <typename DeviceContext, typename T>
struct DotGradFunction<DeviceContext, T, math::EnableComplex<T>> {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_y->numel());
math::ConjFunctor<T> functor(tensor_y->data<T>(), tensor_y->numel(),
tensor_dx->data<T>());
for_range(functor);
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
dx.device(dev) = dx * dout.broadcast(size);
}
if (tensor_dy) {
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_y->numel());
math::ConjFunctor<T> functor(tensor_x->data<T>(), tensor_x->numel(),
tensor_dy->data<T>());
for_range(functor);
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
dy.device(dev) = dy * dout.broadcast(size);
}
} else {
auto dout = framework::EigenMatrix<T>::From(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = framework::EigenMatrix<T>::From(*tensor_y);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_y->numel());
math::ConjFunctor<T> functor(tensor_y->data<T>(), tensor_y->numel(),
tensor_dx->data<T>());
for_range(functor);
auto dx = framework::EigenMatrix<T>::From(*tensor_dx);
dx.device(dev) = dx * dout.broadcast(size);
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = framework::EigenMatrix<T>::From(*tensor_x);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_x->numel());
math::ConjFunctor<T> functor(tensor_x->data<T>(), tensor_x->numel(),
tensor_dy->data<T>());
for_range(functor);
auto dy = framework::EigenMatrix<T>::From(*tensor_dy);
dy.device(dev) = dy * dout.broadcast(size);
}
}
#else
const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_y = tensor_y->data<T>();
const framework::DDim& dim = tensor_x->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = T(data_y[i].real, -data_y[i].imag) * data_dout[s];
}
}
if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_x = tensor_x->data<T>();
const framework::DDim& dim = tensor_y->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = T(data_x[i].real, -data_x[i].imag) * data_dout[s];
}
}
#endif
}
};
template <typename DeviceContext, typename T>
struct DotGradFunction<DeviceContext, T, math::DisableComplex<T>> {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dy) {
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
dy.device(dev) = x * dout.broadcast(size);
}
} else {
auto dout = framework::EigenMatrix<T>::From(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = framework::EigenMatrix<T>::From(*tensor_y);
auto dx = framework::EigenMatrix<T>::From(*tensor_dx);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = framework::EigenMatrix<T>::From(*tensor_x);
auto dy = framework::EigenMatrix<T>::From(*tensor_dy);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
dy.device(dev) = x * dout.broadcast(size);
}
}
#else
auto const *x = tensor_x->data<T>(), *y = tensor_y->data<T>(),
*dz = tensor_dout->data<T>();
auto&& d = tensor_x->dims();
auto const N = tensor_x->numel();
auto const B = d[d.size() - 1];
if (tensor_dx) {
auto* dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
for (auto j = 0; j < N / B; ++j) {
auto const ss = dz[j];
for (auto i = 0; i < B; ++i) *dx++ = *y++ * ss;
}
}
if (tensor_dy) {
auto* dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
for (auto j = 0; j < N / B; ++j) {
auto const ss = dz[j];
for (auto i = 0; i < B; i++) *dy++ = *x++ * ss;
}
}
#endif
}
};
// See Note [ Why still keep the original kernel implementation? ]
template <typename DeviceContext, typename T>
class DotKernel : public framework::OpKernel<T> {
......@@ -249,7 +46,7 @@ class DotKernel : public framework::OpKernel<T> {
auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
// call new kernel
pten::Dot<T>(dev_ctx, *pt_x.get(), *pt_y.get(), pt_out.get());
pten::DotKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), pt_out.get());
}
};
......@@ -266,8 +63,17 @@ class DotGradKernel : public framework::OpKernel<T> {
if (tensor_dx) tensor_dx->mutable_data<T>(ctx.GetPlace());
if (tensor_dy) tensor_dy->mutable_data<T>(ctx.GetPlace());
DotGradFunction<DeviceContext, T>()(tensor_x, tensor_y, tensor_dout,
tensor_dx, tensor_dy, ctx);
auto pt_x = paddle::experimental::MakePtenDenseTensor(*tensor_x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*tensor_y);
auto pt_dout = paddle::experimental::MakePtenDenseTensor(*tensor_dout);
auto pt_dx = paddle::experimental::MakePtenDenseTensor(*tensor_dx);
auto pt_dy = paddle::experimental::MakePtenDenseTensor(*tensor_dy);
auto& dev_ctx = ctx.device_context<DeviceContext>();
// call new kernel
pten::DotGradKernel<T>(dev_ctx, *pt_x, *pt_y, *pt_dout, pt_dx.get(),
pt_dy.get());
}
};
......
......@@ -225,6 +225,10 @@ class Blas {
const framework::Tensor& mat_b, const MatDescriptor& dim_b,
T alpha, framework::Tensor* mat_out, T beta) const;
template <typename T>
void MatMul(const T* mat_a, const MatDescriptor& dim_a, const T* mat_b,
const MatDescriptor& dim_b, T alpha, T* mat_out, T beta) const;
template <typename T>
void VINV(int n, const T* a, T* y) const;
......
......@@ -1249,6 +1249,15 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
const framework::Tensor &mat_b,
const MatDescriptor &dim_b, T alpha,
framework::Tensor *mat_out, T beta) const {
MatMul(mat_a.data<T>(), dim_a, mat_b.data<T>(), dim_b, alpha,
mat_out->data<T>(), beta);
}
template <typename DeviceContext>
template <typename T>
void Blas<DeviceContext>::MatMul(const T *mat_a, const MatDescriptor &dim_a,
const T *mat_b, const MatDescriptor &dim_b,
T alpha, T *mat_out, T beta) const {
PADDLE_ENFORCE_EQ(
dim_a.width_, dim_b.height_,
platform::errors::InvalidArgument(
......@@ -1261,8 +1270,7 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
this->template GEMM<T>(transA, transB, dim_a.height_, dim_b.width_,
dim_a.width_, alpha, mat_a.data<T>(),
mat_b.data<T>(), beta, mat_out->data<T>());
dim_a.width_, alpha, mat_a, mat_b, beta, mat_out);
} else {
PADDLE_ENFORCE_EQ(
dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 ||
......@@ -1273,8 +1281,8 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
"But got dim_a.batch_size = %d, dim_b.batch_size = %d.",
dim_a.batch_size_, dim_b.batch_size_));
this->template BatchedGEMM<T>(
transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha, mat_a,
mat_b, beta, mat_out,
dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_,
dim_a.stride_, dim_b.stride_);
}
......
......@@ -389,6 +389,14 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
tensor.place(), tensor.layout());
}
}
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
return framework::KernelSignature(
"matmul_grad", {"X", "Y", framework::GradVarName("Out")},
{"trans_x", "trans_y"},
{framework::GradVarName("X"), framework::GradVarName("Y")});
}
};
template <typename T>
......@@ -431,6 +439,13 @@ class MatMulV2OpDoubleGrad : public framework::OperatorWithKernel {
context->ShareDim("DOut", "DDOut");
}
}
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
return framework::KernelSignature(
"matmul_double_grad", {"X", "Y", "DOut", "DDX", "DDY"},
{"trans_x", "trans_y"}, {"DX", "DY", "DDOut"});
}
};
template <typename T>
......@@ -500,6 +515,15 @@ class MatMulV2OpTripleGrad : public framework::OperatorWithKernel {
context->ShareDim("Y", "D_DDY_out");
}
}
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
return framework::KernelSignature(
"matmul_triple_grad",
{"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
{"trans_x", "trans_y"},
{"D_X_out", "D_Y_out", "D_DOut_out", "D_DDX_out", "D_DDY_out"});
}
};
template <typename T>
......
......@@ -70,6 +70,12 @@ DenseTensor& DenseTensor::operator=(const DenseTensor& other) {
return *this;
}
DenseTensor& DenseTensor::operator=(DenseTensor&& other) {
meta_ = std::move(other.meta_);
storage_.swap(other.storage_);
return *this;
}
int64_t DenseTensor::numel() const {
if (meta_.is_scalar) {
return 1;
......
......@@ -97,6 +97,8 @@ class DenseTensor : public TensorBase,
/// \brief DenseTensor shallow copy assignment.
DenseTensor& operator=(const DenseTensor& other);
DenseTensor& operator=(DenseTensor&& other);
/// \brief Destroy the tensor object and release exclusive resources.
virtual ~DenseTensor() = default;
......
......@@ -29,6 +29,9 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
{"flatten_contiguous_range", "flatten"},
{"flatten_contiguous_range_grad", "flatten_grad"},
{"matmul_v2", "matmul"},
{"matmul_v2_grad", "matmul_grad"},
{"matmul_v2_grad_grad", "matmul_double_grad"},
{"matmul_v2_triple_grad", "matmul_triple_grad"},
{"reduce_mean", "mean"},
{"reduce_sum", "sum"},
{"reshape2", "reshape"},
......@@ -36,6 +39,8 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
{"flatten", "deprecated"},
{"flatten_grad", "deprecated"},
{"matmul", "deprecated"},
{"matmul_grad", "deprecated"},
{"matmul_grad_grad", "deprecated"},
{"mean", "deprecated"},
{"reshape", "deprecated"},
{"sum", "deprecated"}};
......
......@@ -50,6 +50,11 @@ void KernelContext::EmplaceBackOutputWithoutSetRange(
outputs_.emplace_back(std::move(output));
}
void KernelContext::SetOutputWithoutSetRange(
int index, std::shared_ptr<TensorBase> output) {
outputs_.at(index) = std::move(output);
}
void KernelContext::EmplaceBackOutputs(
paddle::SmallVector<std::shared_ptr<TensorBase>> outputs) {
int index = outputs_.size();
......@@ -119,9 +124,11 @@ void KernelContext::ClearData() {
}
}
for (auto& out : outputs_) {
if (out) {
CompatibleDenseTensorUtils::ClearStorage(
static_cast<DenseTensor*>(out.get()));
}
}
attrs_.clear();
}
} // namespace pten
......@@ -62,6 +62,8 @@ class KernelContext {
void EmplaceBackOutputWithoutSetRange(std::shared_ptr<TensorBase> output);
void SetOutputWithoutSetRange(int index, std::shared_ptr<TensorBase> output);
void EmplaceBackOutputs(
paddle::SmallVector<std::shared_ptr<TensorBase>> outputs);
......@@ -80,6 +82,14 @@ class KernelContext {
return static_cast<const TensorType&>(*(inputs_.at(idx)));
}
template <typename TensorType>
paddle::optional<const TensorType&> OptionalInputAt(size_t idx) const {
const auto& input = inputs_.at(idx);
return input ? paddle::optional<const TensorType&>{static_cast<
const TensorType&>(*input)}
: paddle::optional<const TensorType&>{paddle::none};
}
std::shared_ptr<TensorBase>& MutableInputPtrAt(size_t idx) {
return inputs_.at(idx);
}
......
......@@ -65,6 +65,10 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
} else if (arg_type == std::type_index(typeid(const DenseTensor&))) {
args_def->AppendInput(
default_key.backend(), default_tensor_layout, default_key.dtype());
} else if (arg_type == std::type_index(typeid(
paddle::optional<const DenseTensor&>))) {
args_def->AppendInput(
default_key.backend(), default_tensor_layout, default_key.dtype());
} else if (arg_type ==
std::type_index(typeid(const std::vector<DenseTensor>&))) {
args_def->AppendInput(
......
......@@ -77,6 +77,27 @@ namespace pten {
} \
}
#define PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(tensor_type) \
template <typename... Tail> \
struct KernelCallHelper<paddle::optional<const tensor_type&>, Tail...> { \
template <int dev_ctx_idx, \
int in_idx, \
int attr_idx, \
int out_idx, \
typename... PreviousArgs> \
static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \
static_assert(attr_idx == 0, \
"Kernel's Input should appear before Attributes."); \
static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \
const std::pair<int, int> range = ctx->InputRangeAt(in_idx); \
auto arg = ctx->OptionalInputAt<tensor_type>(range.first); \
KernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \
ctx, pargs..., arg); \
} \
}
#define PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(tensor_type) \
template <typename... Tail> \
struct KernelCallHelper<const std::vector<tensor_type>&, Tail...> { \
......@@ -190,6 +211,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
/* Input Helpers */
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
// TODO(chenweihang): adapt SelectedRows
// PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRowsTensor);
......
......@@ -30,7 +30,7 @@ DenseTensor Dot(const ContextT& dev_ctx,
pten::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()),
std::move(out_meta));
Dot<T, ContextT>(dev_ctx, x, y, &dense_out);
DotKernel<T, ContextT>(dev_ctx, x, y, &dense_out);
return dense_out;
}
......
......@@ -48,15 +48,4 @@ DenseTensor Scale(const ContextT& dev_ctx,
return dense_out;
}
template <typename T, typename ContextT>
DenseTensor Conj(const ContextT& dev_ctx, const DenseTensor& x) {
auto out_meta = UnchangedInferMeta(x.meta());
pten::DenseTensor dense_out(
pten::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()),
std::move(out_meta));
Conj<T>(dev_ctx, x, &dense_out);
return dense_out;
}
} // namespace pten
......@@ -16,9 +16,20 @@ limitations under the License. */
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/empty_kernel.h"
namespace pten {
template <typename T, typename Context>
void Conj(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
void ConjKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
template <typename T, typename Context>
DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) {
auto out_meta = UnchangedInferMeta(x.meta());
auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
ConjKernel<T>(dev_ctx, x, &dense_out);
return dense_out;
}
} // namespace pten
......@@ -24,7 +24,7 @@
PT_REGISTER_CTX_KERNEL(conj,
CPU,
ALL_LAYOUT,
pten::Conj,
pten::ConjKernel,
paddle::platform::complex<float>,
paddle::platform::complex<double>,
float,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/dot_grad_kernel.h"
#include "paddle/pten/kernels/impl/dot_grad_kernel_impl.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_CTX_KERNEL(dot_grad,
CPU,
ALL_LAYOUT,
pten::DotGradKernel,
float,
double,
int,
int64_t,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
......@@ -23,7 +23,7 @@
namespace pten {
template <typename T, typename Context>
void Dot(const Context& dev_ctx,
void DotKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
......@@ -52,7 +52,7 @@ using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_CTX_KERNEL(dot,
CPU,
ALL_LAYOUT,
pten::Dot,
pten::DotKernel,
float,
double,
int,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/kernels/matmul_grad_kernel.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_CTX_KERNEL(matmul_grad,
CPU,
ALL_LAYOUT,
pten::MatmulGradKernel,
float,
double,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_double_grad,
CPU,
ALL_LAYOUT,
pten::MatmulDoubleGradKernel,
float,
double,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_triple_grad,
CPU,
ALL_LAYOUT,
pten::MatmulTripleGradKernel,
float,
double,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
template <typename T, typename Context>
void DotGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy);
template <typename T, typename Context>
void DotDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& ddx,
const DenseTensor& ddy,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy,
DenseTensor* ddout);
template <typename T, typename Context>
void DotTripleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& ddx,
const DenseTensor& ddy,
const DenseTensor& d_dx,
const DenseTensor& d_dy,
const DenseTensor& dout,
const DenseTensor& d_ddout,
DenseTensor* d_x,
DenseTensor* d_y,
DenseTensor* d_ddx,
DenseTensor* d_ddy,
DenseTensor* d_dout);
} // namespace pten
......@@ -19,7 +19,7 @@
namespace pten {
template <typename T, typename Context>
void Dot(const Context& dev_ctx,
void DotKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/kernels/empty_kernel.h"
#include "paddle/pten/backends/all_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/fluid/platform/complex.h"
namespace pten {
template <typename T, typename ContextT>
void EmptyKernel(const ContextT& dev_ctx,
template <typename T, typename Context>
void EmptyKernel(const Context& dev_ctx,
const ScalarArray& shape,
DenseTensor* out) {
out->Resize(paddle::framework::make_ddim(shape.GetData()));
}
template <typename T, typename ContextT>
void EmptyLikeKernel(const ContextT& dev_ctx, DenseTensor* out) {
template <typename T, typename Context>
void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) {
out->mutable_data<T>();
}
......@@ -37,44 +38,62 @@ PT_REGISTER_CTX_KERNEL(empty,
CPU,
ALL_LAYOUT,
pten::EmptyKernel,
bool,
int,
int64_t,
float,
double,
paddle::platform::float16) {}
uint8_t,
int16_t,
int,
int64_t,
bool,
paddle::platform::float16,
paddle::platform::bfloat16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(empty_like,
CPU,
ALL_LAYOUT,
pten::EmptyLikeKernel,
bool,
int,
int64_t,
float,
double,
paddle::platform::float16) {}
uint8_t,
int16_t,
int,
int64_t,
bool,
paddle::platform::float16,
paddle::platform::bfloat16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_CTX_KERNEL(empty,
GPU,
ALL_LAYOUT,
pten::EmptyKernel,
bool,
int,
int64_t,
float,
double,
paddle::platform::float16) {}
uint8_t,
int16_t,
int,
int64_t,
bool,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(empty_like,
GPU,
ALL_LAYOUT,
pten::EmptyLikeKernel,
bool,
int,
int64_t,
float,
double,
paddle::platform::float16) {}
uint8_t,
int16_t,
int,
int64_t,
bool,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
#endif
......@@ -41,6 +41,14 @@ DenseTensor Empty(const Context& dev_ctx, DenseTensorMeta&& meta) {
return dense_out;
}
template <typename T, typename Context>
DenseTensor Empty(const Context& dev_ctx) {
return Empty<T, Context>(dev_ctx,
{paddle::experimental::CppTypeToDataType<T>::Type(),
{-1},
DataLayout::NCHW});
}
template <typename T, typename Context>
DenseTensor Empty(const Context& dev_ctx,
const ScalarArray& shape,
......
......@@ -24,7 +24,8 @@
PT_REGISTER_CTX_KERNEL(conj,
GPU,
ALL_LAYOUT,
pten::Conj,
pten::ConjKernel,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>,
float,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/kernels/dot_grad_kernel.h"
#include "paddle/pten/kernels/impl/dot_grad_kernel_impl.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_CTX_KERNEL(dot_grad,
GPU,
ALL_LAYOUT,
pten::DotGradKernel,
float,
double,
int,
int64_t,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
......@@ -25,7 +25,7 @@
namespace pten {
template <typename T, typename Context>
void Dot(const Context& dev_ctx,
void DotKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
......@@ -55,7 +55,7 @@ using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_CTX_KERNEL(dot,
GPU,
ALL_LAYOUT,
pten::Dot,
pten::DotKernel,
float,
double,
int,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/kernels/matmul_grad_kernel.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_CTX_KERNEL(matmul_grad,
GPU,
ALL_LAYOUT,
pten::MatmulGradKernel,
float,
double,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_double_grad,
GPU,
ALL_LAYOUT,
pten::MatmulDoubleGradKernel,
float,
double,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_triple_grad,
GPU,
ALL_LAYOUT,
pten::MatmulTripleGradKernel,
float,
double,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
......@@ -17,6 +17,9 @@
#include "paddle/fluid/framework/ddim.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/pten/kernels/hybird/eigen/common.h"
namespace pten {
namespace math {
......@@ -30,5 +33,30 @@ struct TransposeNormal {
const std::vector<int64_t>& axis);
};
template <typename DeviceContext, typename T, int Rank>
struct Transpose {
void operator()(const DeviceContext& dev_ctx,
const DenseTensor& in,
DenseTensor* out,
const std::vector<int>& axis) {
Eigen::array<int, Rank> permute;
for (int i = 0; i < Rank; i++) {
permute[i] = axis[i];
}
auto eigen_in = pten::EigenTensor<T, Rank>::From(in);
auto eigen_out = pten::EigenTensor<T, Rank>::From(*out);
auto* dev = dev_ctx.eigen_device();
// use 32bit index to speed up computation
bool use_32bit_index = eigen_out.size() < Eigen::NumTraits<int>::highest();
bool is_gpu_place = paddle::platform::is_gpu_place(dev_ctx.GetPlace());
if (use_32bit_index && is_gpu_place) {
To32BitIndex(eigen_out).device(*dev) =
To32BitIndex(eigen_in).shuffle(permute);
} else {
eigen_out.device(*dev) = eigen_in.shuffle(permute);
}
}
};
} // namespace math
} // namespace pten
......@@ -21,12 +21,14 @@
namespace pten {
template <typename T, typename Context>
void Conj(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) {
void ConjKernel(const Context& context,
const DenseTensor& x,
DenseTensor* out) {
auto numel = x.numel();
auto* x_data = x.data<T>();
auto* out_data = out->mutable_data<T>();
paddle::platform::ForRange<Context> for_range(dev_ctx, numel);
paddle::platform::ForRange<Context> for_range(context, numel);
paddle::operators::math::ConjFunctor<T> functor(x_data, numel, out_data);
for_range(functor);
}
......
此差异已折叠。
此差异已折叠。
......@@ -86,7 +86,7 @@ static void IndexIncreaseFromDims(const int ndim,
}
template <typename Context, typename T>
void MatMulFunction(const Context& context,
void MatMulFunction(const Context& dev_ctx,
const DenseTensor& X,
const DenseTensor& Y,
const std::vector<std::int64_t>& x_dims,
......@@ -102,7 +102,7 @@ void MatMulFunction(const Context& context,
const T* x_data = X.data<T>();
const T* y_data = Y.data<T>();
auto blas = paddle::operators::math::GetBlas<Context, T>(context);
auto blas = paddle::operators::math::GetBlas<Context, T>(dev_ctx);
if (x_ndim == 1 && y_ndim == 1) {
const int M = X.numel();
......@@ -117,6 +117,8 @@ void MatMulFunction(const Context& context,
M,
N));
VLOG(3) << "MatMul's case 1";
Out->Resize({1});
Out->mutable_data<T>();
blas.GEMM(CblasNoTrans,
CblasTrans,
1,
......@@ -471,7 +473,7 @@ void MatMulFunction(const Context& context,
}
template <typename Context, typename T>
void MatMulFunction(const Context& context,
void MatMulFunction(const Context& dev_ctx,
const DenseTensor& X,
const DenseTensor& Y,
DenseTensor* Out,
......@@ -481,11 +483,11 @@ void MatMulFunction(const Context& context,
const std::vector<std::int64_t> x_dims = vectorize(X.dims());
const std::vector<std::int64_t> y_dims = vectorize(Y.dims());
MatMulFunction<Context, T>(
context, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag);
dev_ctx, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag);
}
template <typename T, typename Context>
void MatmulKernel(const Context& context,
void MatmulKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
bool transpose_x,
......@@ -501,7 +503,7 @@ void MatmulKernel(const Context& context,
paddle::platform::errors::InvalidArgument(
"The Input(Y) dims size must not be equal 0,"
" but reviced dims size is 0. "));
MatMulFunction<Context, T>(context, x, y, out, transpose_x, transpose_y);
MatMulFunction<Context, T>(dev_ctx, x, y, out, transpose_x, transpose_y);
}
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace pten {
template <typename T, typename Context>
void MatmulGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
bool transpose_x,
bool transpose_y,
DenseTensor* dx,
DenseTensor* dy);
template <typename T, typename Context>
void MatmulDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
paddle::optional<const DenseTensor&> ddx,
paddle::optional<const DenseTensor&> ddy,
bool transpose_x,
bool transpose_y,
DenseTensor* dx,
DenseTensor* dy,
DenseTensor* ddout);
template <typename T, typename Context>
void MatmulTripleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
const DenseTensor& ddx,
const DenseTensor& ddy,
paddle::optional<const DenseTensor&> d_dx,
paddle::optional<const DenseTensor&> d_dy,
paddle::optional<const DenseTensor&> d_ddout,
bool transpose_x,
bool transpose_y,
DenseTensor* out_d_x,
DenseTensor* out_d_y,
DenseTensor* out_d_dout,
DenseTensor* out_d_ddx,
DenseTensor* out_d_ddy);
} // namespace pten
......@@ -14,14 +14,15 @@
#pragma once
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/infermeta/binary.h"
#include "paddle/pten/kernels/empty_kernel.h"
namespace pten {
template <typename T, typename Context>
void MatmulKernel(const Context& context,
void MatmulKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
bool transpose_x,
......@@ -29,17 +30,14 @@ void MatmulKernel(const Context& context,
DenseTensor* out);
template <typename T, typename Context>
DenseTensor Matmul(const Context& context,
DenseTensor Matmul(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
bool transpose_x,
bool transpose_y) {
auto out_meta = MatmulInferMeta(x.meta(), y.meta(), transpose_x, transpose_y);
DenseTensor dense_out(
pten::make_intrusive<paddle::experimental::SharedStorage>(
context.GetPlace()),
std::move(out_meta));
MatmulKernel<T, Context>(context, x, y, transpose_x, transpose_y, &dense_out);
auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
MatmulKernel<T, Context>(dev_ctx, x, y, transpose_x, transpose_y, &dense_out);
return dense_out;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册