From 69394d1a155855b3a63c8cc229299236a278c0ed Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Wed, 9 Feb 2022 14:13:46 +0800 Subject: [PATCH] Revert "Move trace op to pten (#39227)" This reverts commit d7dddf9486bf7e39456be44b5a747db16644e818. --- paddle/fluid/operators/trace_op.cc | 20 +- paddle/fluid/operators/trace_op.cu | 77 +++++++ .../operators/trace_op.h} | 191 ++++++++++-------- paddle/pten/kernels/cpu/trace_grad_kernel.cc | 31 --- paddle/pten/kernels/cpu/trace_kernel.cc | 58 ------ paddle/pten/kernels/gpu/trace_grad_kernel.cu | 31 --- paddle/pten/kernels/gpu/trace_kernel.cu | 57 ------ paddle/pten/kernels/trace_grad_kernel.h | 30 --- paddle/pten/kernels/trace_kernel.h | 29 --- paddle/pten/ops/compat/trace_sig.cc | 34 ---- .../fluid/tests/unittests/test_trace_op.py | 2 - 11 files changed, 200 insertions(+), 360 deletions(-) create mode 100644 paddle/fluid/operators/trace_op.cu rename paddle/{pten/kernels/impl/trace_kernel_impl.h => fluid/operators/trace_op.h} (50%) delete mode 100644 paddle/pten/kernels/cpu/trace_grad_kernel.cc delete mode 100644 paddle/pten/kernels/cpu/trace_kernel.cc delete mode 100644 paddle/pten/kernels/gpu/trace_grad_kernel.cu delete mode 100644 paddle/pten/kernels/gpu/trace_kernel.cu delete mode 100644 paddle/pten/kernels/trace_grad_kernel.h delete mode 100644 paddle/pten/kernels/trace_kernel.h delete mode 100644 paddle/pten/ops/compat/trace_sig.cc diff --git a/paddle/fluid/operators/trace_op.cc b/paddle/fluid/operators/trace_op.cc index aabad64c894..de71a089b69 100644 --- a/paddle/fluid/operators/trace_op.cc +++ b/paddle/fluid/operators/trace_op.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/trace_op.h" #include "paddle/fluid/framework/op_version_registry.h" namespace paddle { @@ -161,6 +161,24 @@ REGISTER_OPERATOR(trace, ops::TraceOp, ops::TraceOpMaker, REGISTER_OPERATOR(trace_grad, ops::TraceOpGrad, ops::TraceGradNoNeedBufferVarsInferer); +REGISTER_OP_CPU_KERNEL( + trace, ops::TraceKernel, + ops::TraceKernel, + ops::TraceKernel, + ops::TraceKernel, + ops::TraceKernel>, + ops::TraceKernel>); +REGISTER_OP_CPU_KERNEL( + trace_grad, ops::TraceGradKernel, + ops::TraceGradKernel, + ops::TraceGradKernel, + ops::TraceGradKernel, + ops::TraceGradKernel>, + ops::TraceGradKernel>); /* ========================== register checkpoint ===========================*/ REGISTER_OP_VERSION(trace) diff --git a/paddle/fluid/operators/trace_op.cu b/paddle/fluid/operators/trace_op.cu new file mode 100644 index 00000000000..3d8a60dd65f --- /dev/null +++ b/paddle/fluid/operators/trace_op.cu @@ -0,0 +1,77 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" +#include "paddle/fluid/operators/trace_op.h" + +namespace paddle { +namespace operators { + +template +class TraceCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* input = context.Input("Input"); + auto* out = context.Output("Out"); + + const int64_t offset = context.Attr("offset"); + const int64_t dim1 = context.Attr("axis1"); + const int64_t dim2 = context.Attr("axis2"); + + T* out_data = out->mutable_data(context.GetPlace()); + const framework::Tensor diag = + Diagonal(context, input, offset, dim1, dim2); + if (diag.numel() > 0) { + auto stream = context.cuda_device_context().stream(); + std::vector reduce_dims; + reduce_dims.push_back(out->dims().size()); + TensorReduceImpl>( + context.cuda_device_context(), diag, out, kps::IdentityFunctor(), + reduce_dims, stream); + } else { + math::SetConstant functor; + functor(context.device_context(), out, static_cast(0)); + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace platform = paddle::platform; +REGISTER_OP_CUDA_KERNEL( + trace, ops::TraceCUDAKernel, + ops::TraceCUDAKernel, + ops::TraceCUDAKernel, + ops::TraceCUDAKernel, + ops::TraceCUDAKernel, + ops::TraceCUDAKernel>, + ops::TraceCUDAKernel>); +REGISTER_OP_CUDA_KERNEL( + trace_grad, ops::TraceGradKernel, + ops::TraceGradKernel, + ops::TraceGradKernel, + ops::TraceGradKernel, + ops::TraceGradKernel, + ops::TraceGradKernel>, + ops::TraceGradKernel>); diff --git a/paddle/pten/kernels/impl/trace_kernel_impl.h b/paddle/fluid/operators/trace_op.h similarity index 50% rename from paddle/pten/kernels/impl/trace_kernel_impl.h rename to paddle/fluid/operators/trace_op.h index 4dbba9bc69e..ca9439cbed9 100644 --- a/paddle/pten/kernels/impl/trace_kernel_impl.h +++ b/paddle/fluid/operators/trace_op.h @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,26 +13,20 @@ // limitations under the License. #pragma once - -#if defined(__NVCC__) || defined(__HIPCC__) -#include -#include -#endif - #include - +#include #include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/for_range.h" -namespace pten { +namespace paddle { +namespace operators { + template struct DiagonalFunctor { - DiagonalFunctor(const T* input, - const int64_t* diag_stride, - const int64_t* ret_strides, - int64_t pos, - int64_t dim_size, + DiagonalFunctor(const T* input, const int64_t* diag_stride, + const int64_t* ret_strides, int64_t pos, int64_t dim_size, T* diag) : input_(input), diag_stride_(diag_stride), @@ -61,15 +55,9 @@ struct DiagonalFunctor { template struct TraceGradFunctor { - TraceGradFunctor(const T* d_out, - const int64_t* out_stride, - const int64_t* x_strides, - int64_t pos, - int64_t dim_size, - int64_t dim1, - int64_t dim2, - int64_t diag_size, - T* d_x) + TraceGradFunctor(const T* d_out, const int64_t* out_stride, + const int64_t* x_strides, int64_t pos, int64_t dim_size, + int64_t dim1, int64_t dim2, int64_t diag_size, T* d_x) : d_out_(d_out), out_stride_(out_stride), x_strides_(x_strides), @@ -113,12 +101,10 @@ struct TraceGradFunctor { T* d_x_; }; -template -DenseTensor Diagonal(const DeviceContext& context, - const DenseTensor* input, - int64_t offset, - int64_t dim1, - int64_t dim2) { +template +framework::Tensor Diagonal(const framework::ExecutionContext& context, + const framework::Tensor* input, int64_t offset, + int64_t dim1, int64_t dim2) { auto* input_data = input->data(); auto input_dims = input->dims(); auto input_stride = framework::stride(input_dims); @@ -152,7 +138,7 @@ DenseTensor Diagonal(const DeviceContext& context, } ret_strides.push_back(stride1 + stride2); ret_dims.push_back(diag_size); - DenseTensor diag; + framework::Tensor diag; framework::DDim diag_dims = framework::make_ddim(ret_dims); auto dig_stride = framework::stride(diag_dims); auto diag_data = diag.mutable_data(diag_dims, context.GetPlace()); @@ -169,10 +155,10 @@ DenseTensor Diagonal(const DeviceContext& context, const auto* ret_arr = ret_strides.data(); #endif - // auto& dev_ctx = context.template device_context(); - paddle::platform::ForRange for_range(context, diag.numel()); - DiagonalFunctor functor( - input_data, diag_arr, ret_arr, pos, dim_size, diag_data); + auto& dev_ctx = context.template device_context(); + platform::ForRange for_range(dev_ctx, diag.numel()); + DiagonalFunctor functor(input_data, diag_arr, ret_arr, pos, dim_size, + diag_data); for_range(functor); return diag; } else { @@ -180,68 +166,99 @@ DenseTensor Diagonal(const DeviceContext& context, } } -template -void TraceGradKernel(const Context& ctx, - const DenseTensor& out_grad, - const DenseTensor& x, - int offset, - int axis1, - int axis2, - DenseTensor* in_grad) { - auto input_dims = in_grad->dims(); - auto input_stride = framework::stride(input_dims); - auto output_dims = out_grad.dims(); - auto output_stride = framework::stride(output_dims); +template +class TraceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* input = context.Input("Input"); + auto* out = context.Output("Out"); - auto* out_data = out_grad.data(); - T* x_data = in_grad->mutable_data(ctx.GetPlace()); + const int64_t offset = context.Attr("offset"); + const int64_t dim1 = context.Attr("axis1"); + const int64_t dim2 = context.Attr("axis2"); - paddle::operators::math::SetConstant set_zero; + auto output_dims = out->dims(); - set_zero(ctx, in_grad, static_cast(0.0)); - auto dim1 = axis1; - auto dim2 = axis2; - auto dim1_ = dim1 < 0 ? input_dims.size() + dim1 : dim1; - auto dim2_ = dim2 < 0 ? input_dims.size() + dim2 : dim2; - auto len1 = input_dims[std::min(dim1_, dim2_)]; - auto len2 = input_dims[std::max(dim1_, dim2_)]; - auto stride1 = input_stride[std::min(dim1_, dim2_)]; - auto stride2 = input_stride[std::max(dim1_, dim2_)]; + T* out_data = out->mutable_data(context.GetPlace()); - int offset_stride = 0; - if (offset >= 0) { - offset_stride = stride2; - len2 -= offset; - } else { - offset_stride = stride1; - len1 += offset; + const framework::Tensor diag = + Diagonal(context, input, offset, dim1, dim2); + if (diag.numel() > 0) { + auto x = framework::EigenMatrix::Reshape(diag, diag.dims().size() - 1); + auto output = framework::EigenVector::Flatten(*out); + auto& place = + *context.template device_context().eigen_device(); + auto reduce_dim = Eigen::array({1}); + output.device(place) = x.sum(reduce_dim); + out->Resize(output_dims); + } else { + std::fill(out_data, out_data + out->numel(), static_cast(0)); + } } - int64_t diag_size = len2 < len1 ? len2 : len1; - int64_t pos = std::abs(offset) * offset_stride; - if (diag_size > 0) { +}; + +template +class TraceGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const auto* d_out = + context.Input(framework::GradVarName("Out")); + auto* d_x = + context.Output(framework::GradVarName("Input")); + + int64_t offset = context.Attr("offset"); + int64_t dim1 = context.Attr("axis1"); + int64_t dim2 = context.Attr("axis2"); + + auto input_dims = d_x->dims(); + auto input_stride = framework::stride(input_dims); + auto output_dims = d_out->dims(); + auto output_stride = framework::stride(output_dims); + + auto* out_data = d_out->data(); + T* x_data = d_x->mutable_data(context.GetPlace()); + + math::SetConstant set_zero; + auto& dev_ctx = context.template device_context(); + set_zero(dev_ctx, d_x, static_cast(0.0)); + + auto dim1_ = dim1 < 0 ? input_dims.size() + dim1 : dim1; + auto dim2_ = dim2 < 0 ? input_dims.size() + dim2 : dim2; + auto len1 = input_dims[std::min(dim1_, dim2_)]; + auto len2 = input_dims[std::max(dim1_, dim2_)]; + auto stride1 = input_stride[std::min(dim1_, dim2_)]; + auto stride2 = input_stride[std::max(dim1_, dim2_)]; + + int offset_stride = 0; + if (offset >= 0) { + offset_stride = stride2; + len2 -= offset; + } else { + offset_stride = stride1; + len1 += offset; + } + int64_t diag_size = len2 < len1 ? len2 : len1; + int64_t pos = std::abs(offset) * offset_stride; + if (diag_size > 0) { #if defined(__NVCC__) || defined(__HIPCC__) - thrust::device_vector output_vec(vectorize(output_stride)); - const int64_t* output_arr = thrust::raw_pointer_cast(output_vec.data()); - thrust::device_vector input_vec(vectorize(input_stride)); - const int64_t* input_arr = thrust::raw_pointer_cast(input_vec.data()); + thrust::device_vector output_vec(vectorize(output_stride)); + const int64_t* output_arr = thrust::raw_pointer_cast(output_vec.data()); + thrust::device_vector input_vec(vectorize(input_stride)); + const int64_t* input_arr = thrust::raw_pointer_cast(input_vec.data()); #else - const auto* output_arr = output_stride.Get(); - const auto* input_arr = input_stride.Get(); + const auto* output_arr = output_stride.Get(); + const auto* input_arr = input_stride.Get(); #endif - paddle::platform::ForRange for_range(ctx, in_grad->numel()); - TraceGradFunctor functor(out_data, - output_arr, - input_arr, - pos, - input_dims.size(), - dim1_, - dim2_, - diag_size, - x_data); - for_range(functor); + platform::ForRange for_range(dev_ctx, d_x->numel()); + TraceGradFunctor functor(out_data, output_arr, input_arr, pos, + input_dims.size(), dim1_, dim2_, diag_size, + x_data); + for_range(functor); + } } -} +}; -} // namespace pten +} // namespace operators +} // namespace paddle diff --git a/paddle/pten/kernels/cpu/trace_grad_kernel.cc b/paddle/pten/kernels/cpu/trace_grad_kernel.cc deleted file mode 100644 index 136b941ea8b..00000000000 --- a/paddle/pten/kernels/cpu/trace_grad_kernel.cc +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/pten/kernels/trace_grad_kernel.h" -#include "paddle/pten/kernels/impl/trace_kernel_impl.h" - -#include "paddle/pten/backends/cpu/cpu_context.h" -#include "paddle/pten/core/kernel_registry.h" - -PT_REGISTER_KERNEL(trace_grad, - CPU, - ALL_LAYOUT, - pten::TraceGradKernel, - float, - double, - int, - int64_t, - paddle::platform::float16, - paddle::platform::complex, - paddle::platform::complex) {} diff --git a/paddle/pten/kernels/cpu/trace_kernel.cc b/paddle/pten/kernels/cpu/trace_kernel.cc deleted file mode 100644 index 4064b752ef4..00000000000 --- a/paddle/pten/kernels/cpu/trace_kernel.cc +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/pten/kernels/trace_kernel.h" -#include "paddle/pten/backends/cpu/cpu_context.h" -#include "paddle/pten/core/kernel_registry.h" -#include "paddle/pten/kernels/impl/trace_kernel_impl.h" - -namespace pten { - -template -void TraceKernel(const Context& ctx, - const DenseTensor& x, - int offset, - int axis1, - int axis2, - DenseTensor* out) { - auto output_dims = out->dims(); - - T* out_data = out->mutable_data(ctx.GetPlace()); - - const DenseTensor diag = Diagonal(ctx, &x, offset, axis1, axis2); - if (diag.numel() > 0) { - auto x = paddle::framework::EigenMatrix::Reshape(diag, - diag.dims().size() - 1); - auto output = paddle::framework::EigenVector::Flatten(*out); - auto reduce_dim = Eigen::array({1}); - output.device(*ctx.eigen_device()) = x.sum(reduce_dim); - out->Resize(output_dims); - } else { - std::fill(out_data, out_data + out->numel(), static_cast(0)); - } -} - -} // namespace pten - -PT_REGISTER_KERNEL(trace, - CPU, - ALL_LAYOUT, - pten::TraceKernel, - float, - double, - int, - int64_t, - paddle::platform::float16, - paddle::platform::complex, - paddle::platform::complex) {} diff --git a/paddle/pten/kernels/gpu/trace_grad_kernel.cu b/paddle/pten/kernels/gpu/trace_grad_kernel.cu deleted file mode 100644 index b1b22e5ce54..00000000000 --- a/paddle/pten/kernels/gpu/trace_grad_kernel.cu +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/pten/kernels/impl/trace_kernel_impl.h" -#include "paddle/pten/kernels/trace_grad_kernel.h" - -#include "paddle/pten/backends/cpu/cpu_context.h" -#include "paddle/pten/core/kernel_registry.h" - -PT_REGISTER_KERNEL(trace_grad, - GPU, - ALL_LAYOUT, - pten::TraceGradKernel, - float, - double, - int, - int64_t, - paddle::platform::float16, - paddle::platform::complex, - paddle::platform::complex) {} diff --git a/paddle/pten/kernels/gpu/trace_kernel.cu b/paddle/pten/kernels/gpu/trace_kernel.cu deleted file mode 100644 index b0bfb3a8a51..00000000000 --- a/paddle/pten/kernels/gpu/trace_kernel.cu +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/pten/backends/gpu/gpu_context.h" -#include "paddle/pten/core/kernel_registry.h" -#include "paddle/pten/kernels/gpu/reduce.h" -#include "paddle/pten/kernels/impl/trace_kernel_impl.h" -#include "paddle/pten/kernels/trace_kernel.h" - -namespace pten { - -template -void TraceKernel(const Context& ctx, - const DenseTensor& x, - int offset, - int axis1, - int axis2, - DenseTensor* out) { - T* out_data = out->mutable_data(ctx.GetPlace()); - auto diag = Diagonal(ctx, &x, offset, axis1, axis2); - if (diag.numel() > 0) { - auto stream = ctx.stream(); - std::vector reduce_dims; - reduce_dims.push_back(out->dims().size()); - kernels:: - TensorReduceFunctorImpl>( - ctx, diag, out, kps::IdentityFunctor(), reduce_dims, stream); - } else { - paddle::operators::math::SetConstant functor; - functor(ctx, out, static_cast(0)); - } -} - -} // namespace pten - -PT_REGISTER_KERNEL(trace, - GPU, - ALL_LAYOUT, - pten::TraceKernel, - float, - double, - int, - int64_t, - paddle::platform::float16, - paddle::platform::complex, - paddle::platform::complex) {} diff --git a/paddle/pten/kernels/trace_grad_kernel.h b/paddle/pten/kernels/trace_grad_kernel.h deleted file mode 100644 index 9dad4e6ac90..00000000000 --- a/paddle/pten/kernels/trace_grad_kernel.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/pten/core/dense_tensor.h" - -namespace pten { - -template -void TraceGradKernel(const Context& ctx, - const DenseTensor& out_grad, - const DenseTensor& x, - int offset, - int axis1, - int axis2, - DenseTensor* in_grad); - -} // namespace pten diff --git a/paddle/pten/kernels/trace_kernel.h b/paddle/pten/kernels/trace_kernel.h deleted file mode 100644 index 7a68e32508a..00000000000 --- a/paddle/pten/kernels/trace_kernel.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/pten/core/dense_tensor.h" - -namespace pten { - -template -void TraceKernel(const Context& ctx, - const DenseTensor& x, - int offset, - int axis1, - int axis2, - DenseTensor* out); - -} // namespace pten diff --git a/paddle/pten/ops/compat/trace_sig.cc b/paddle/pten/ops/compat/trace_sig.cc deleted file mode 100644 index a36beaaee7a..00000000000 --- a/paddle/pten/ops/compat/trace_sig.cc +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/pten/core/compat/op_utils.h" - -namespace pten { - -KernelSignature TraceOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature( - "trace", {"Input"}, {"offset", "axis1", "axis2"}, {"Out"}); -} - -KernelSignature TraceGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("trace_grad", - {GradVarName("Out"), "Input"}, - {"offset", "axis1", "axis2"}, - {GradVarName("Input")}); -} - -} // namespace pten - -PT_REGISTER_ARG_MAPPING_FN(trace, pten::TraceOpArgumentMapping); -PT_REGISTER_ARG_MAPPING_FN(trace_grad, pten::TraceGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_trace_op.py b/python/paddle/fluid/tests/unittests/test_trace_op.py index 3320b240e56..7441ff24329 100644 --- a/python/paddle/fluid/tests/unittests/test_trace_op.py +++ b/python/paddle/fluid/tests/unittests/test_trace_op.py @@ -21,7 +21,6 @@ import paddle.nn.functional as F import paddle.fluid as fluid import paddle.fluid.core as core import paddle.tensor as tensor -import paddle class TestTraceOp(OpTest): @@ -87,5 +86,4 @@ class TestTraceAPICase(unittest.TestCase): if __name__ == "__main__": - paddle.enable_static() unittest.main() -- GitLab