未验证 提交 d7dddf94 编写于 作者: H hong 提交者: GitHub

Move trace op to pten (#39227)

* add trace op

* bug fix

* bug fix; test=develop

* thrust bug fix; test=develop

* remove useless register; test=develop

* fix bug; test=develop

* update trace kernel; test=develop

* move kernel args to trace_sig; test=develop
上级 91b074a2
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/trace_op.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
...@@ -161,24 +161,6 @@ REGISTER_OPERATOR(trace, ops::TraceOp, ops::TraceOpMaker, ...@@ -161,24 +161,6 @@ REGISTER_OPERATOR(trace, ops::TraceOp, ops::TraceOpMaker,
REGISTER_OPERATOR(trace_grad, ops::TraceOpGrad, REGISTER_OPERATOR(trace_grad, ops::TraceOpGrad,
ops::TraceGradNoNeedBufferVarsInferer); ops::TraceGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
trace, ops::TraceKernel<paddle::platform::CPUDeviceContext, int>,
ops::TraceKernel<paddle::platform::CPUDeviceContext, float>,
ops::TraceKernel<paddle::platform::CPUDeviceContext, double>,
ops::TraceKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TraceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::TraceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
trace_grad, ops::TraceGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::TraceGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TraceGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::TraceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TraceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::TraceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
/* ========================== register checkpoint ===========================*/ /* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(trace) REGISTER_OP_VERSION(trace)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/trace_op.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class TraceCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<framework::Tensor>("Input");
auto* out = context.Output<framework::Tensor>("Out");
const int64_t offset = context.Attr<int>("offset");
const int64_t dim1 = context.Attr<int>("axis1");
const int64_t dim2 = context.Attr<int>("axis2");
T* out_data = out->mutable_data<T>(context.GetPlace());
const framework::Tensor diag =
Diagonal<DeviceContext, T>(context, input, offset, dim1, dim2);
if (diag.numel() > 0) {
auto stream = context.cuda_device_context().stream();
std::vector<int> reduce_dims;
reduce_dims.push_back(out->dims().size());
TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
context.cuda_device_context(), diag, out, kps::IdentityFunctor<T>(),
reduce_dims, stream);
} else {
math::SetConstant<DeviceContext, T> functor;
functor(context.device_context<DeviceContext>(), out, static_cast<T>(0));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace platform = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
trace, ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext,
platform::float16>,
ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
trace_grad, ops::TraceGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::TraceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::TraceGradKernel<paddle::platform::CUDADeviceContext,
platform::float16>,
ops::TraceGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::TraceGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::TraceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::TraceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/trace_grad_kernel.h"
#include "paddle/pten/kernels/impl/trace_kernel_impl.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
PT_REGISTER_KERNEL(trace_grad,
CPU,
ALL_LAYOUT,
pten::TraceGradKernel,
float,
double,
int,
int64_t,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/trace_kernel.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/trace_kernel_impl.h"
namespace pten {
template <typename T, typename Context>
void TraceKernel(const Context& ctx,
const DenseTensor& x,
int offset,
int axis1,
int axis2,
DenseTensor* out) {
auto output_dims = out->dims();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
const DenseTensor diag = Diagonal<T, Context>(ctx, &x, offset, axis1, axis2);
if (diag.numel() > 0) {
auto x = paddle::framework::EigenMatrix<T>::Reshape(diag,
diag.dims().size() - 1);
auto output = paddle::framework::EigenVector<T>::Flatten(*out);
auto reduce_dim = Eigen::array<int, 1>({1});
output.device(*ctx.eigen_device()) = x.sum(reduce_dim);
out->Resize(output_dims);
} else {
std::fill(out_data, out_data + out->numel(), static_cast<T>(0));
}
}
} // namespace pten
PT_REGISTER_KERNEL(trace,
CPU,
ALL_LAYOUT,
pten::TraceKernel,
float,
double,
int,
int64_t,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/impl/trace_kernel_impl.h"
#include "paddle/pten/kernels/trace_grad_kernel.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
PT_REGISTER_KERNEL(trace_grad,
GPU,
ALL_LAYOUT,
pten::TraceGradKernel,
float,
double,
int,
int64_t,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/gpu/reduce.h"
#include "paddle/pten/kernels/impl/trace_kernel_impl.h"
#include "paddle/pten/kernels/trace_kernel.h"
namespace pten {
template <typename T, typename Context>
void TraceKernel(const Context& ctx,
const DenseTensor& x,
int offset,
int axis1,
int axis2,
DenseTensor* out) {
T* out_data = out->mutable_data<T>(ctx.GetPlace());
auto diag = Diagonal<T, Context>(ctx, &x, offset, axis1, axis2);
if (diag.numel() > 0) {
auto stream = ctx.stream();
std::vector<int> reduce_dims;
reduce_dims.push_back(out->dims().size());
kernels::
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
ctx, diag, out, kps::IdentityFunctor<T>(), reduce_dims, stream);
} else {
paddle::operators::math::SetConstant<Context, T> functor;
functor(ctx, out, static_cast<T>(0));
}
}
} // namespace pten
PT_REGISTER_KERNEL(trace,
GPU,
ALL_LAYOUT,
pten::TraceKernel,
float,
double,
int,
int64_t,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -13,20 +13,26 @@ ...@@ -13,20 +13,26 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#if defined(__NVCC__) || defined(__HIPCC__)
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#endif
#include <algorithm> #include <algorithm>
#include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
namespace paddle { namespace pten {
namespace operators {
template <typename T> template <typename T>
struct DiagonalFunctor { struct DiagonalFunctor {
DiagonalFunctor(const T* input, const int64_t* diag_stride, DiagonalFunctor(const T* input,
const int64_t* ret_strides, int64_t pos, int64_t dim_size, const int64_t* diag_stride,
const int64_t* ret_strides,
int64_t pos,
int64_t dim_size,
T* diag) T* diag)
: input_(input), : input_(input),
diag_stride_(diag_stride), diag_stride_(diag_stride),
...@@ -55,9 +61,15 @@ struct DiagonalFunctor { ...@@ -55,9 +61,15 @@ struct DiagonalFunctor {
template <typename T> template <typename T>
struct TraceGradFunctor { struct TraceGradFunctor {
TraceGradFunctor(const T* d_out, const int64_t* out_stride, TraceGradFunctor(const T* d_out,
const int64_t* x_strides, int64_t pos, int64_t dim_size, const int64_t* out_stride,
int64_t dim1, int64_t dim2, int64_t diag_size, T* d_x) const int64_t* x_strides,
int64_t pos,
int64_t dim_size,
int64_t dim1,
int64_t dim2,
int64_t diag_size,
T* d_x)
: d_out_(d_out), : d_out_(d_out),
out_stride_(out_stride), out_stride_(out_stride),
x_strides_(x_strides), x_strides_(x_strides),
...@@ -101,10 +113,12 @@ struct TraceGradFunctor { ...@@ -101,10 +113,12 @@ struct TraceGradFunctor {
T* d_x_; T* d_x_;
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
framework::Tensor Diagonal(const framework::ExecutionContext& context, DenseTensor Diagonal(const DeviceContext& context,
const framework::Tensor* input, int64_t offset, const DenseTensor* input,
int64_t dim1, int64_t dim2) { int64_t offset,
int64_t dim1,
int64_t dim2) {
auto* input_data = input->data<T>(); auto* input_data = input->data<T>();
auto input_dims = input->dims(); auto input_dims = input->dims();
auto input_stride = framework::stride(input_dims); auto input_stride = framework::stride(input_dims);
...@@ -138,7 +152,7 @@ framework::Tensor Diagonal(const framework::ExecutionContext& context, ...@@ -138,7 +152,7 @@ framework::Tensor Diagonal(const framework::ExecutionContext& context,
} }
ret_strides.push_back(stride1 + stride2); ret_strides.push_back(stride1 + stride2);
ret_dims.push_back(diag_size); ret_dims.push_back(diag_size);
framework::Tensor diag; DenseTensor diag;
framework::DDim diag_dims = framework::make_ddim(ret_dims); framework::DDim diag_dims = framework::make_ddim(ret_dims);
auto dig_stride = framework::stride(diag_dims); auto dig_stride = framework::stride(diag_dims);
auto diag_data = diag.mutable_data<T>(diag_dims, context.GetPlace()); auto diag_data = diag.mutable_data<T>(diag_dims, context.GetPlace());
...@@ -155,10 +169,10 @@ framework::Tensor Diagonal(const framework::ExecutionContext& context, ...@@ -155,10 +169,10 @@ framework::Tensor Diagonal(const framework::ExecutionContext& context,
const auto* ret_arr = ret_strides.data(); const auto* ret_arr = ret_strides.data();
#endif #endif
auto& dev_ctx = context.template device_context<DeviceContext>(); // auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, diag.numel()); paddle::platform::ForRange<DeviceContext> for_range(context, diag.numel());
DiagonalFunctor<T> functor(input_data, diag_arr, ret_arr, pos, dim_size, DiagonalFunctor<T> functor(
diag_data); input_data, diag_arr, ret_arr, pos, dim_size, diag_data);
for_range(functor); for_range(functor);
return diag; return diag;
} else { } else {
...@@ -166,99 +180,68 @@ framework::Tensor Diagonal(const framework::ExecutionContext& context, ...@@ -166,99 +180,68 @@ framework::Tensor Diagonal(const framework::ExecutionContext& context,
} }
} }
template <typename DeviceContext, typename T> template <typename T, typename Context>
class TraceKernel : public framework::OpKernel<T> { void TraceGradKernel(const Context& ctx,
public: const DenseTensor& out_grad,
void Compute(const framework::ExecutionContext& context) const override { const DenseTensor& x,
auto* input = context.Input<framework::Tensor>("Input"); int offset,
auto* out = context.Output<framework::Tensor>("Out"); int axis1,
int axis2,
DenseTensor* in_grad) {
auto input_dims = in_grad->dims();
auto input_stride = framework::stride(input_dims);
auto output_dims = out_grad.dims();
auto output_stride = framework::stride(output_dims);
const int64_t offset = context.Attr<int>("offset"); auto* out_data = out_grad.data<T>();
const int64_t dim1 = context.Attr<int>("axis1"); T* x_data = in_grad->mutable_data<T>(ctx.GetPlace());
const int64_t dim2 = context.Attr<int>("axis2");
auto output_dims = out->dims(); paddle::operators::math::SetConstant<Context, T> set_zero;
T* out_data = out->mutable_data<T>(context.GetPlace()); set_zero(ctx, in_grad, static_cast<T>(0.0));
auto dim1 = axis1;
auto dim2 = axis2;
auto dim1_ = dim1 < 0 ? input_dims.size() + dim1 : dim1;
auto dim2_ = dim2 < 0 ? input_dims.size() + dim2 : dim2;
auto len1 = input_dims[std::min(dim1_, dim2_)];
auto len2 = input_dims[std::max(dim1_, dim2_)];
auto stride1 = input_stride[std::min(dim1_, dim2_)];
auto stride2 = input_stride[std::max(dim1_, dim2_)];
const framework::Tensor diag = int offset_stride = 0;
Diagonal<DeviceContext, T>(context, input, offset, dim1, dim2); if (offset >= 0) {
if (diag.numel() > 0) { offset_stride = stride2;
auto x = framework::EigenMatrix<T>::Reshape(diag, diag.dims().size() - 1); len2 -= offset;
auto output = framework::EigenVector<T>::Flatten(*out); } else {
auto& place = offset_stride = stride1;
*context.template device_context<DeviceContext>().eigen_device(); len1 += offset;
auto reduce_dim = Eigen::array<int, 1>({1});
output.device(place) = x.sum(reduce_dim);
out->Resize(output_dims);
} else {
std::fill(out_data, out_data + out->numel(), static_cast<T>(0));
}
} }
}; int64_t diag_size = len2 < len1 ? len2 : len1;
int64_t pos = std::abs(offset) * offset_stride;
template <typename DeviceContext, typename T> if (diag_size > 0) {
class TraceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const auto* d_out =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_x =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
int64_t offset = context.Attr<int>("offset");
int64_t dim1 = context.Attr<int>("axis1");
int64_t dim2 = context.Attr<int>("axis2");
auto input_dims = d_x->dims();
auto input_stride = framework::stride(input_dims);
auto output_dims = d_out->dims();
auto output_stride = framework::stride(output_dims);
auto* out_data = d_out->data<T>();
T* x_data = d_x->mutable_data<T>(context.GetPlace());
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
set_zero(dev_ctx, d_x, static_cast<T>(0.0));
auto dim1_ = dim1 < 0 ? input_dims.size() + dim1 : dim1;
auto dim2_ = dim2 < 0 ? input_dims.size() + dim2 : dim2;
auto len1 = input_dims[std::min(dim1_, dim2_)];
auto len2 = input_dims[std::max(dim1_, dim2_)];
auto stride1 = input_stride[std::min(dim1_, dim2_)];
auto stride2 = input_stride[std::max(dim1_, dim2_)];
int offset_stride = 0;
if (offset >= 0) {
offset_stride = stride2;
len2 -= offset;
} else {
offset_stride = stride1;
len1 += offset;
}
int64_t diag_size = len2 < len1 ? len2 : len1;
int64_t pos = std::abs(offset) * offset_stride;
if (diag_size > 0) {
#if defined(__NVCC__) || defined(__HIPCC__) #if defined(__NVCC__) || defined(__HIPCC__)
thrust::device_vector<int64_t> output_vec(vectorize(output_stride)); thrust::device_vector<int64_t> output_vec(vectorize(output_stride));
const int64_t* output_arr = thrust::raw_pointer_cast(output_vec.data()); const int64_t* output_arr = thrust::raw_pointer_cast(output_vec.data());
thrust::device_vector<int64_t> input_vec(vectorize(input_stride)); thrust::device_vector<int64_t> input_vec(vectorize(input_stride));
const int64_t* input_arr = thrust::raw_pointer_cast(input_vec.data()); const int64_t* input_arr = thrust::raw_pointer_cast(input_vec.data());
#else #else
const auto* output_arr = output_stride.Get(); const auto* output_arr = output_stride.Get();
const auto* input_arr = input_stride.Get(); const auto* input_arr = input_stride.Get();
#endif #endif
platform::ForRange<DeviceContext> for_range(dev_ctx, d_x->numel()); paddle::platform::ForRange<Context> for_range(ctx, in_grad->numel());
TraceGradFunctor<T> functor(out_data, output_arr, input_arr, pos, TraceGradFunctor<T> functor(out_data,
input_dims.size(), dim1_, dim2_, diag_size, output_arr,
x_data); input_arr,
for_range(functor); pos,
} input_dims.size(),
dim1_,
dim2_,
diag_size,
x_data);
for_range(functor);
} }
}; }
} // namespace operators } // namespace pten
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
template <typename T, typename Context>
void TraceGradKernel(const Context& ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
int offset,
int axis1,
int axis2,
DenseTensor* in_grad);
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
template <typename T, typename Context>
void TraceKernel(const Context& ctx,
const DenseTensor& x,
int offset,
int axis1,
int axis2,
DenseTensor* out);
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
KernelSignature TraceOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"trace", {"Input"}, {"offset", "axis1", "axis2"}, {"Out"});
}
KernelSignature TraceGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("trace_grad",
{GradVarName("Out"), "Input"},
{"offset", "axis1", "axis2"},
{GradVarName("Input")});
}
} // namespace pten
PT_REGISTER_ARG_MAPPING_FN(trace, pten::TraceOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(trace_grad, pten::TraceGradOpArgumentMapping);
...@@ -21,6 +21,7 @@ import paddle.nn.functional as F ...@@ -21,6 +21,7 @@ import paddle.nn.functional as F
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.tensor as tensor import paddle.tensor as tensor
import paddle
class TestTraceOp(OpTest): class TestTraceOp(OpTest):
...@@ -86,4 +87,5 @@ class TestTraceAPICase(unittest.TestCase): ...@@ -86,4 +87,5 @@ class TestTraceAPICase(unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册