diff --git a/paddle/fluid/operators/trace_op.cc b/paddle/fluid/operators/trace_op.cc index de71a089b692a9f2ea4c3c59c1fa85cbc47b1e33..aabad64c894df516ceeb8c3f7f753f3aa4fc70d3 100644 --- a/paddle/fluid/operators/trace_op.cc +++ b/paddle/fluid/operators/trace_op.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/trace_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" namespace paddle { @@ -161,24 +161,6 @@ REGISTER_OPERATOR(trace, ops::TraceOp, ops::TraceOpMaker, REGISTER_OPERATOR(trace_grad, ops::TraceOpGrad, ops::TraceGradNoNeedBufferVarsInferer); -REGISTER_OP_CPU_KERNEL( - trace, ops::TraceKernel, - ops::TraceKernel, - ops::TraceKernel, - ops::TraceKernel, - ops::TraceKernel>, - ops::TraceKernel>); -REGISTER_OP_CPU_KERNEL( - trace_grad, ops::TraceGradKernel, - ops::TraceGradKernel, - ops::TraceGradKernel, - ops::TraceGradKernel, - ops::TraceGradKernel>, - ops::TraceGradKernel>); /* ========================== register checkpoint ===========================*/ REGISTER_OP_VERSION(trace) diff --git a/paddle/fluid/operators/trace_op.cu b/paddle/fluid/operators/trace_op.cu deleted file mode 100644 index 3d8a60dd65fc6fe3cf85b4c507ffffff647c414c..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/trace_op.cu +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" -#include "paddle/fluid/operators/trace_op.h" - -namespace paddle { -namespace operators { - -template -class TraceCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* input = context.Input("Input"); - auto* out = context.Output("Out"); - - const int64_t offset = context.Attr("offset"); - const int64_t dim1 = context.Attr("axis1"); - const int64_t dim2 = context.Attr("axis2"); - - T* out_data = out->mutable_data(context.GetPlace()); - const framework::Tensor diag = - Diagonal(context, input, offset, dim1, dim2); - if (diag.numel() > 0) { - auto stream = context.cuda_device_context().stream(); - std::vector reduce_dims; - reduce_dims.push_back(out->dims().size()); - TensorReduceImpl>( - context.cuda_device_context(), diag, out, kps::IdentityFunctor(), - reduce_dims, stream); - } else { - math::SetConstant functor; - functor(context.device_context(), out, static_cast(0)); - } - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace platform = paddle::platform; -REGISTER_OP_CUDA_KERNEL( - trace, ops::TraceCUDAKernel, - ops::TraceCUDAKernel, - ops::TraceCUDAKernel, - ops::TraceCUDAKernel, - ops::TraceCUDAKernel, - ops::TraceCUDAKernel>, - ops::TraceCUDAKernel>); -REGISTER_OP_CUDA_KERNEL( - trace_grad, ops::TraceGradKernel, - ops::TraceGradKernel, - ops::TraceGradKernel, - ops::TraceGradKernel, - ops::TraceGradKernel, - ops::TraceGradKernel>, - ops::TraceGradKernel>); diff --git a/paddle/pten/kernels/cpu/trace_grad_kernel.cc b/paddle/pten/kernels/cpu/trace_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..136b941ea8b0f0d8d11ba458c35f118e1a6e685e --- /dev/null +++ b/paddle/pten/kernels/cpu/trace_grad_kernel.cc @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/trace_grad_kernel.h" +#include "paddle/pten/kernels/impl/trace_kernel_impl.h" + +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/kernel_registry.h" + +PT_REGISTER_KERNEL(trace_grad, + CPU, + ALL_LAYOUT, + pten::TraceGradKernel, + float, + double, + int, + int64_t, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/cpu/trace_kernel.cc b/paddle/pten/kernels/cpu/trace_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..4064b752ef4ca87994f6e365cf89349a74c57adc --- /dev/null +++ b/paddle/pten/kernels/cpu/trace_kernel.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/trace_kernel.h" +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/impl/trace_kernel_impl.h" + +namespace pten { + +template +void TraceKernel(const Context& ctx, + const DenseTensor& x, + int offset, + int axis1, + int axis2, + DenseTensor* out) { + auto output_dims = out->dims(); + + T* out_data = out->mutable_data(ctx.GetPlace()); + + const DenseTensor diag = Diagonal(ctx, &x, offset, axis1, axis2); + if (diag.numel() > 0) { + auto x = paddle::framework::EigenMatrix::Reshape(diag, + diag.dims().size() - 1); + auto output = paddle::framework::EigenVector::Flatten(*out); + auto reduce_dim = Eigen::array({1}); + output.device(*ctx.eigen_device()) = x.sum(reduce_dim); + out->Resize(output_dims); + } else { + std::fill(out_data, out_data + out->numel(), static_cast(0)); + } +} + +} // namespace pten + +PT_REGISTER_KERNEL(trace, + CPU, + ALL_LAYOUT, + pten::TraceKernel, + float, + double, + int, + int64_t, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/gpu/trace_grad_kernel.cu b/paddle/pten/kernels/gpu/trace_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..b1b22e5ce549609b1f4d4adb843f726c0eefd7a3 --- /dev/null +++ b/paddle/pten/kernels/gpu/trace_grad_kernel.cu @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/impl/trace_kernel_impl.h" +#include "paddle/pten/kernels/trace_grad_kernel.h" + +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/kernel_registry.h" + +PT_REGISTER_KERNEL(trace_grad, + GPU, + ALL_LAYOUT, + pten::TraceGradKernel, + float, + double, + int, + int64_t, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/gpu/trace_kernel.cu b/paddle/pten/kernels/gpu/trace_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..b0bfb3a8a51865df55a507506fe870016ac9181e --- /dev/null +++ b/paddle/pten/kernels/gpu/trace_kernel.cu @@ -0,0 +1,57 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/backends/gpu/gpu_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/gpu/reduce.h" +#include "paddle/pten/kernels/impl/trace_kernel_impl.h" +#include "paddle/pten/kernels/trace_kernel.h" + +namespace pten { + +template +void TraceKernel(const Context& ctx, + const DenseTensor& x, + int offset, + int axis1, + int axis2, + DenseTensor* out) { + T* out_data = out->mutable_data(ctx.GetPlace()); + auto diag = Diagonal(ctx, &x, offset, axis1, axis2); + if (diag.numel() > 0) { + auto stream = ctx.stream(); + std::vector reduce_dims; + reduce_dims.push_back(out->dims().size()); + kernels:: + TensorReduceFunctorImpl>( + ctx, diag, out, kps::IdentityFunctor(), reduce_dims, stream); + } else { + paddle::operators::math::SetConstant functor; + functor(ctx, out, static_cast(0)); + } +} + +} // namespace pten + +PT_REGISTER_KERNEL(trace, + GPU, + ALL_LAYOUT, + pten::TraceKernel, + float, + double, + int, + int64_t, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/fluid/operators/trace_op.h b/paddle/pten/kernels/impl/trace_kernel_impl.h similarity index 50% rename from paddle/fluid/operators/trace_op.h rename to paddle/pten/kernels/impl/trace_kernel_impl.h index ca9439cbed97ddb02e2e6eaa2fb89628e738576e..4dbba9bc69e616c08fc050afc027421568ea5647 100644 --- a/paddle/fluid/operators/trace_op.h +++ b/paddle/pten/kernels/impl/trace_kernel_impl.h @@ -1,4 +1,4 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,20 +13,26 @@ // limitations under the License. #pragma once + +#if defined(__NVCC__) || defined(__HIPCC__) +#include +#include +#endif + #include -#include + #include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/for_range.h" -namespace paddle { -namespace operators { - +namespace pten { template struct DiagonalFunctor { - DiagonalFunctor(const T* input, const int64_t* diag_stride, - const int64_t* ret_strides, int64_t pos, int64_t dim_size, + DiagonalFunctor(const T* input, + const int64_t* diag_stride, + const int64_t* ret_strides, + int64_t pos, + int64_t dim_size, T* diag) : input_(input), diag_stride_(diag_stride), @@ -55,9 +61,15 @@ struct DiagonalFunctor { template struct TraceGradFunctor { - TraceGradFunctor(const T* d_out, const int64_t* out_stride, - const int64_t* x_strides, int64_t pos, int64_t dim_size, - int64_t dim1, int64_t dim2, int64_t diag_size, T* d_x) + TraceGradFunctor(const T* d_out, + const int64_t* out_stride, + const int64_t* x_strides, + int64_t pos, + int64_t dim_size, + int64_t dim1, + int64_t dim2, + int64_t diag_size, + T* d_x) : d_out_(d_out), out_stride_(out_stride), x_strides_(x_strides), @@ -101,10 +113,12 @@ struct TraceGradFunctor { T* d_x_; }; -template -framework::Tensor Diagonal(const framework::ExecutionContext& context, - const framework::Tensor* input, int64_t offset, - int64_t dim1, int64_t dim2) { +template +DenseTensor Diagonal(const DeviceContext& context, + const DenseTensor* input, + int64_t offset, + int64_t dim1, + int64_t dim2) { auto* input_data = input->data(); auto input_dims = input->dims(); auto input_stride = framework::stride(input_dims); @@ -138,7 +152,7 @@ framework::Tensor Diagonal(const framework::ExecutionContext& context, } ret_strides.push_back(stride1 + stride2); ret_dims.push_back(diag_size); - framework::Tensor diag; + DenseTensor diag; framework::DDim diag_dims = framework::make_ddim(ret_dims); auto dig_stride = framework::stride(diag_dims); auto diag_data = diag.mutable_data(diag_dims, context.GetPlace()); @@ -155,10 +169,10 @@ framework::Tensor Diagonal(const framework::ExecutionContext& context, const auto* ret_arr = ret_strides.data(); #endif - auto& dev_ctx = context.template device_context(); - platform::ForRange for_range(dev_ctx, diag.numel()); - DiagonalFunctor functor(input_data, diag_arr, ret_arr, pos, dim_size, - diag_data); + // auto& dev_ctx = context.template device_context(); + paddle::platform::ForRange for_range(context, diag.numel()); + DiagonalFunctor functor( + input_data, diag_arr, ret_arr, pos, dim_size, diag_data); for_range(functor); return diag; } else { @@ -166,99 +180,68 @@ framework::Tensor Diagonal(const framework::ExecutionContext& context, } } -template -class TraceKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* input = context.Input("Input"); - auto* out = context.Output("Out"); +template +void TraceGradKernel(const Context& ctx, + const DenseTensor& out_grad, + const DenseTensor& x, + int offset, + int axis1, + int axis2, + DenseTensor* in_grad) { + auto input_dims = in_grad->dims(); + auto input_stride = framework::stride(input_dims); + auto output_dims = out_grad.dims(); + auto output_stride = framework::stride(output_dims); - const int64_t offset = context.Attr("offset"); - const int64_t dim1 = context.Attr("axis1"); - const int64_t dim2 = context.Attr("axis2"); + auto* out_data = out_grad.data(); + T* x_data = in_grad->mutable_data(ctx.GetPlace()); - auto output_dims = out->dims(); + paddle::operators::math::SetConstant set_zero; - T* out_data = out->mutable_data(context.GetPlace()); + set_zero(ctx, in_grad, static_cast(0.0)); + auto dim1 = axis1; + auto dim2 = axis2; + auto dim1_ = dim1 < 0 ? input_dims.size() + dim1 : dim1; + auto dim2_ = dim2 < 0 ? input_dims.size() + dim2 : dim2; + auto len1 = input_dims[std::min(dim1_, dim2_)]; + auto len2 = input_dims[std::max(dim1_, dim2_)]; + auto stride1 = input_stride[std::min(dim1_, dim2_)]; + auto stride2 = input_stride[std::max(dim1_, dim2_)]; - const framework::Tensor diag = - Diagonal(context, input, offset, dim1, dim2); - if (diag.numel() > 0) { - auto x = framework::EigenMatrix::Reshape(diag, diag.dims().size() - 1); - auto output = framework::EigenVector::Flatten(*out); - auto& place = - *context.template device_context().eigen_device(); - auto reduce_dim = Eigen::array({1}); - output.device(place) = x.sum(reduce_dim); - out->Resize(output_dims); - } else { - std::fill(out_data, out_data + out->numel(), static_cast(0)); - } + int offset_stride = 0; + if (offset >= 0) { + offset_stride = stride2; + len2 -= offset; + } else { + offset_stride = stride1; + len1 += offset; } -}; - -template -class TraceGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const auto* d_out = - context.Input(framework::GradVarName("Out")); - auto* d_x = - context.Output(framework::GradVarName("Input")); - - int64_t offset = context.Attr("offset"); - int64_t dim1 = context.Attr("axis1"); - int64_t dim2 = context.Attr("axis2"); - - auto input_dims = d_x->dims(); - auto input_stride = framework::stride(input_dims); - auto output_dims = d_out->dims(); - auto output_stride = framework::stride(output_dims); - - auto* out_data = d_out->data(); - T* x_data = d_x->mutable_data(context.GetPlace()); - - math::SetConstant set_zero; - auto& dev_ctx = context.template device_context(); - set_zero(dev_ctx, d_x, static_cast(0.0)); - - auto dim1_ = dim1 < 0 ? input_dims.size() + dim1 : dim1; - auto dim2_ = dim2 < 0 ? input_dims.size() + dim2 : dim2; - auto len1 = input_dims[std::min(dim1_, dim2_)]; - auto len2 = input_dims[std::max(dim1_, dim2_)]; - auto stride1 = input_stride[std::min(dim1_, dim2_)]; - auto stride2 = input_stride[std::max(dim1_, dim2_)]; - - int offset_stride = 0; - if (offset >= 0) { - offset_stride = stride2; - len2 -= offset; - } else { - offset_stride = stride1; - len1 += offset; - } - int64_t diag_size = len2 < len1 ? len2 : len1; - int64_t pos = std::abs(offset) * offset_stride; - if (diag_size > 0) { + int64_t diag_size = len2 < len1 ? len2 : len1; + int64_t pos = std::abs(offset) * offset_stride; + if (diag_size > 0) { #if defined(__NVCC__) || defined(__HIPCC__) - thrust::device_vector output_vec(vectorize(output_stride)); - const int64_t* output_arr = thrust::raw_pointer_cast(output_vec.data()); - thrust::device_vector input_vec(vectorize(input_stride)); - const int64_t* input_arr = thrust::raw_pointer_cast(input_vec.data()); + thrust::device_vector output_vec(vectorize(output_stride)); + const int64_t* output_arr = thrust::raw_pointer_cast(output_vec.data()); + thrust::device_vector input_vec(vectorize(input_stride)); + const int64_t* input_arr = thrust::raw_pointer_cast(input_vec.data()); #else - const auto* output_arr = output_stride.Get(); - const auto* input_arr = input_stride.Get(); + const auto* output_arr = output_stride.Get(); + const auto* input_arr = input_stride.Get(); #endif - platform::ForRange for_range(dev_ctx, d_x->numel()); - TraceGradFunctor functor(out_data, output_arr, input_arr, pos, - input_dims.size(), dim1_, dim2_, diag_size, - x_data); - for_range(functor); - } + paddle::platform::ForRange for_range(ctx, in_grad->numel()); + TraceGradFunctor functor(out_data, + output_arr, + input_arr, + pos, + input_dims.size(), + dim1_, + dim2_, + diag_size, + x_data); + for_range(functor); } -}; +} -} // namespace operators -} // namespace paddle +} // namespace pten diff --git a/paddle/pten/kernels/trace_grad_kernel.h b/paddle/pten/kernels/trace_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..9dad4e6ac90a2b1dd795ba66383a2ef233b4dc88 --- /dev/null +++ b/paddle/pten/kernels/trace_grad_kernel.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" + +namespace pten { + +template +void TraceGradKernel(const Context& ctx, + const DenseTensor& out_grad, + const DenseTensor& x, + int offset, + int axis1, + int axis2, + DenseTensor* in_grad); + +} // namespace pten diff --git a/paddle/pten/kernels/trace_kernel.h b/paddle/pten/kernels/trace_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..7a68e32508a9b1b571e71857c4c486d999c8dfb5 --- /dev/null +++ b/paddle/pten/kernels/trace_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" + +namespace pten { + +template +void TraceKernel(const Context& ctx, + const DenseTensor& x, + int offset, + int axis1, + int axis2, + DenseTensor* out); + +} // namespace pten diff --git a/paddle/pten/ops/compat/trace_sig.cc b/paddle/pten/ops/compat/trace_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..a36beaaee7a9707f37f74ac31d52561561e27f4d --- /dev/null +++ b/paddle/pten/ops/compat/trace_sig.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature TraceOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature( + "trace", {"Input"}, {"offset", "axis1", "axis2"}, {"Out"}); +} + +KernelSignature TraceGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("trace_grad", + {GradVarName("Out"), "Input"}, + {"offset", "axis1", "axis2"}, + {GradVarName("Input")}); +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(trace, pten::TraceOpArgumentMapping); +PT_REGISTER_ARG_MAPPING_FN(trace_grad, pten::TraceGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_trace_op.py b/python/paddle/fluid/tests/unittests/test_trace_op.py index 7441ff24329fa5a5f57e253ba480d1c7d9ab647c..3320b240e56155f8bfc2f6d8e43306f36e651869 100644 --- a/python/paddle/fluid/tests/unittests/test_trace_op.py +++ b/python/paddle/fluid/tests/unittests/test_trace_op.py @@ -21,6 +21,7 @@ import paddle.nn.functional as F import paddle.fluid as fluid import paddle.fluid.core as core import paddle.tensor as tensor +import paddle class TestTraceOp(OpTest): @@ -86,4 +87,5 @@ class TestTraceAPICase(unittest.TestCase): if __name__ == "__main__": + paddle.enable_static() unittest.main()