diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc
index f559027fdd4b02732cfa5d73114c7d9da8bef21c..5d1851fb85aa2fa04f59ca440aafed14644d0c06 100644
--- a/paddle/phi/api/lib/api_custom_impl.cc
+++ b/paddle/phi/api/lib/api_custom_impl.cc
@@ -475,6 +475,54 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
   return api_output;
 }
 
+std::vector<Tensor> unbind_impl(const Tensor& input, int axis) {
+  auto kernel_key_set = ParseKernelKeyByInputArgs(input);
+  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
+
+  Backend kernel_backend = kernel_key.backend();
+  DataLayout kernel_layout = kernel_key.layout();
+  DataType kernel_data_type = kernel_key.dtype();
+
+  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
+      "unbind", {kernel_backend, kernel_layout, kernel_data_type});
+  VLOG(6) << "unbind API kernel key: [" << kernel_backend << ", "
+          << kernel_layout << ", " << kernel_data_type << "]";
+  VLOG(6) << "unbind API kernel: " << kernel;
+
+  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
+
+  auto dense_input = PrepareData(input, kernel.InputAt(0), {});
+
+  // Calculate the number of out tensors
+  auto input_shape = input.dims();
+  if (axis < 0) {
+    axis = input_shape.size() + axis;
+  }
+  auto out_num = input_shape[axis];
+
+  std::vector<Tensor> out;
+  auto dense_outs = SetKernelOutput(out_num, kernel_backend, &out);
+  std::vector<phi::MetaTensor> meta_outs;
+  meta_outs.reserve(out_num);
+  std::vector<phi::MetaTensor*> meta_out_ptrs;
+  meta_out_ptrs.reserve(out_num);
+  for (int64_t i = 0; i < out_num; ++i) {
+    meta_outs.push_back(dense_outs[i]);
+    meta_out_ptrs.push_back(&meta_outs.back());
+  }
+
+  phi::UnbindInferMeta(MakeMetaTensor(*dense_input), axis, meta_out_ptrs);
+
+  using kernel_signature = void (*)(const phi::DeviceContext&,
+                                    const phi::DenseTensor&,
+                                    int,
+                                    std::vector<phi::DenseTensor*>&);
+  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
+  (*kernel_fn)(*dev_ctx, *dense_input, axis, dense_outs);
+
+  return out;
+}
+
 ////////////////// Backward(grad) api impls //////////////////////
 
 // TODO(chenweihang): the original sum grad op can support higher-level
diff --git a/paddle/phi/api/lib/api_custom_impl.h b/paddle/phi/api/lib/api_custom_impl.h
index 4745782d914cabb1310695a05ebcf9914f53aa24..80ace229316a92b3b190557bdee3fc70a2ebe2c4 100644
--- a/paddle/phi/api/lib/api_custom_impl.h
+++ b/paddle/phi/api/lib/api_custom_impl.h
@@ -14,6 +14,8 @@ limitations under the License. */
 
 #pragma once
 
+#include <vector>
+
 #include "paddle/phi/api/include/tensor.h"
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/common/place.h"
@@ -73,6 +75,8 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
     bool multi_precision,
     float rescale_grad);
 
+std::vector<Tensor> unbind_impl(const Tensor& input, int axis);
+
 ////////////////// Backward(grad) api impls //////////////////////
 
 std::vector<Tensor> add_n_grad_impl(const std::vector<Tensor>& x,
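For reference, `unbind_impl` derives the number of outputs from the size of the (sign-normalized) axis, and each output drops that dimension. A minimal Python sketch of the same semantics through the public `paddle.unbind` API (the shape and axis values here are illustrative):

```python
import paddle

x = paddle.arange(30, dtype="float32").reshape([2, 3, 5])

# unbind_impl normalizes negative axes: axis < 0  =>  axis += rank(x)
axis = -2
norm_axis = axis + len(x.shape) if axis < 0 else axis  # -> 1
out_num = x.shape[norm_axis]                           # -> 3

outs = paddle.unbind(x, axis=axis)
assert len(outs) == out_num
# Each output drops the unbound dimension: [2, 3, 5] -> [2, 5]
assert list(outs[0].shape) == [2, 5]
```
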
*/ #include "paddle/phi/api/lib/kernel_dispatch.h" -#include "paddle/phi/api/include/context_pool.h" -#include "paddle/phi/core/compat/convert_utils.h" #ifdef _MSC_VER #include #endif +#include "paddle/phi/api/include/context_pool.h" +#include "paddle/phi/core/compat/convert_utils.h" +#include "paddle/phi/core/string_tensor_utils.h" +#include "paddle/phi/core/tensor_utils.h" + namespace paddle { namespace experimental { namespace detail { +// We need judge whether the allocation is nullptr, +// whether the allocation is initialized, wo we need GetHolder method +bool HasAllocation(const phi::TensorBase& t) { + if (phi::DenseTensor::classof(&t)) { + return phi::DenseTensorUtils::GetHolder( + static_cast(t)) != nullptr; + } else if (phi::SelectedRows::classof(&t)) { + return phi::DenseTensorUtils::GetHolder( + static_cast(t).value()) != nullptr; + } else if (phi::SparseCsrTensor::classof(&t)) { + return phi::DenseTensorUtils::GetHolder( + static_cast(t) + .non_zero_elements()) != nullptr; + } else if (phi::SparseCooTensor::classof(&t)) { + return phi::DenseTensorUtils::GetHolder( + static_cast(t) + .non_zero_elements()) != nullptr; + } else if (phi::StringTensor::classof(&t)) { + return phi::StringTensorUtils::GetHolder( + static_cast(t)) != nullptr; + } else { + return false; + } +} + BackendSet GetTensorBackendSet(const phi::TensorBase& t) { - if (t.initialized()) { + if (HasAllocation(t)) { BackendSet backend_set(phi::TransToPhiBackend(t.place())); switch (t.layout()) { case DataLayout::MKLDNN: diff --git a/paddle/phi/core/string_tensor_utils.h b/paddle/phi/core/string_tensor_utils.h index c1b0d09647d91c0529e0db952937d5585be9e9d9..777a24c9adfe15bf3dfafeda57c734b6c1c9a665 100644 --- a/paddle/phi/core/string_tensor_utils.h +++ b/paddle/phi/core/string_tensor_utils.h @@ -23,6 +23,11 @@ class StringTensorUtils { static StringTensorMeta* GetMutableMeta(StringTensor* tensor) { return &(tensor->meta_); } + + static const std::shared_ptr& GetHolder( + const StringTensor& tensor) { + return tensor.holder_; + } }; } // namespace phi diff --git a/paddle/phi/core/tensor_utils.h b/paddle/phi/core/tensor_utils.h index 676a590ecbce23a107bcc891c37ac69406854035..abf8aeff4d3ab047809bad8ba902075824cf263e 100644 --- a/paddle/phi/core/tensor_utils.h +++ b/paddle/phi/core/tensor_utils.h @@ -25,6 +25,11 @@ class DenseTensorUtils { return &(tensor->meta_); } + static const std::shared_ptr& GetHolder( + const DenseTensor& tensor) { + return tensor.holder_; + } + static DenseTensor Slice(const DenseTensor& tensor, int64_t begin_idx, int64_t end_idx) { diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index e0ea637074c2027402317c49a210abd3325f83f5..0fedcca255c90f8336fd347a2a2d7280d9f89d57 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2429,7 +2429,7 @@ void TransposeGradInferMeta(const MetaTensor& x, void UnbindInferMeta(const MetaTensor& x, int axis, - std::vector* outs) { + std::vector outs) { auto in_dims = x.dims(); std::vector out_dim; axis = axis < 0 ? 
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index e0ea637074c2027402317c49a210abd3325f83f5..0fedcca255c90f8336fd347a2a2d7280d9f89d57 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -2429,7 +2429,7 @@ void TransposeGradInferMeta(const MetaTensor& x,
 
 void UnbindInferMeta(const MetaTensor& x,
                      int axis,
-                     std::vector<MetaTensor>* outs) {
+                     std::vector<MetaTensor*> outs) {
   auto in_dims = x.dims();
   std::vector<int> out_dim;
   axis = axis < 0 ? in_dims.size() + axis : axis;
@@ -2438,11 +2438,11 @@ void UnbindInferMeta(const MetaTensor& x,
   }
   auto out_dims = phi::make_ddim(out_dim);
 
-  for (size_t i = 0; i < outs->size(); ++i) {
-    (*outs)[i].set_dtype(x.dtype());
-    (*outs)[i].set_dims(out_dims);
-    (*outs)[i].set_layout(x.layout());
-    (*outs)[i].share_lod(x);
+  for (size_t i = 0; i < outs.size(); ++i) {
+    outs[i]->set_dtype(x.dtype());
+    outs[i]->set_dims(out_dims);
+    outs[i]->set_layout(x.layout());
+    outs[i]->share_lod(x);
   }
 }
 
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index 5106c6f4487336741fda8855caab8e6628a0c2e9..1d69c9504d9cdaae410a7368602684acbc1fa2ae 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -365,7 +365,7 @@ void TrilTriuInferMeta(const MetaTensor& x,
 
 void UnbindInferMeta(const MetaTensor& x,
                      int axis,
-                     std::vector<MetaTensor>* outs);
+                     std::vector<MetaTensor*> outs);
 
 void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out);
diff --git a/paddle/phi/kernels/cpu/truncated_gaussian_random_kernel.cc b/paddle/phi/kernels/cpu/truncated_gaussian_random_kernel.cc
index 4247e597acef4aac14f93066a3ea6232734e0c8c..10280082619194a4763ae995526c4a54ee8dfd06 100644
--- a/paddle/phi/kernels/cpu/truncated_gaussian_random_kernel.cc
+++ b/paddle/phi/kernels/cpu/truncated_gaussian_random_kernel.cc
@@ -21,10 +21,141 @@
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 
-#include "paddle/fluid/framework/generator.h"
-
 namespace phi {
 
+// reference: https://gist.github.com/lakshayg/d80172fe5ae3c5d2c2aedb53c250320e
+template <typename T>
+T Erfinv(T x) {
+  if (x < -1 || x > 1) {
+    return std::numeric_limits<T>::quiet_NaN();
+  } else if (x == 1.0) {
+    return std::numeric_limits<T>::infinity();
+  } else if (x == -1.0) {
+    return -std::numeric_limits<T>::infinity();
+  }
+
+  const T LN2 = 6.931471805599453094172321214581e-1;
+
+  const T A0 = 1.1975323115670912564578e0;
+  const T A1 = 4.7072688112383978012285e1;
+  const T A2 = 6.9706266534389598238465e2;
+  const T A3 = 4.8548868893843886794648e3;
+  const T A4 = 1.6235862515167575384252e4;
+  const T A5 = 2.3782041382114385731252e4;
+  const T A6 = 1.1819493347062294404278e4;
+  const T A7 = 8.8709406962545514830200e2;
+
+  const T B0 = 1.0000000000000000000e0;
+  const T B1 = 4.2313330701600911252e1;
+  const T B2 = 6.8718700749205790830e2;
+  const T B3 = 5.3941960214247511077e3;
+  const T B4 = 2.1213794301586595867e4;
+  const T B5 = 3.9307895800092710610e4;
+  const T B6 = 2.8729085735721942674e4;
+  const T B7 = 5.2264952788528545610e3;
+
+  const T C0 = 1.42343711074968357734e0;
+  const T C1 = 4.63033784615654529590e0;
+  const T C2 = 5.76949722146069140550e0;
+  const T C3 = 3.64784832476320460504e0;
+  const T C4 = 1.27045825245236838258e0;
+  const T C5 = 2.41780725177450611770e-1;
+  const T C6 = 2.27238449892691845833e-2;
+  const T C7 = 7.74545014278341407640e-4;
+
+  const T D0 = 1.4142135623730950488016887e0;
+  const T D1 = 2.9036514445419946173133295e0;
+  const T D2 = 2.3707661626024532365971225e0;
+  const T D3 = 9.7547832001787427186894837e-1;
+  const T D4 = 2.0945065210512749128288442e-1;
+  const T D5 = 2.1494160384252876777097297e-2;
+  const T D6 = 7.7441459065157709165577218e-4;
+  const T D7 = 1.4859850019840355905497876e-9;
+
+  const T E0 = 6.65790464350110377720e0;
+  const T E1 = 5.46378491116411436990e0;
+  const T E2 = 1.78482653991729133580e0;
+  const T E3 = 2.96560571828504891230e-1;
+  const T E4 = 2.65321895265761230930e-2;
+  const T E5 = 1.24266094738807843860e-3;
+  const T E6 = 2.71155556874348757815e-5;
+  const T E7 = 2.01033439929228813265e-7;
+
+  const T F0 = 1.414213562373095048801689e0;
+  const T F1 = 8.482908416595164588112026e-1;
+  const T F2 = 1.936480946950659106176712e-1;
+  const T F3 = 2.103693768272068968719679e-2;
+  const T F4 = 1.112800997078859844711555e-3;
+  const T F5 = 2.611088405080593625138020e-5;
+  const T F6 = 2.010321207683943062279931e-7;
+  const T F7 = 2.891024605872965461538222e-15;
+
+  T abs_x = abs(x);
+
+  if (abs_x <= 0.85) {
+    T r = 0.180625 - 0.25 * x * x;
+    T num =
+        (((((((A7 * r + A6) * r + A5) * r + A4) * r + A3) * r + A2) * r + A1) *
+             r +
+         A0);
+    T den =
+        (((((((B7 * r + B6) * r + B5) * r + B4) * r + B3) * r + B2) * r + B1) *
+             r +
+         B0);
+    return x * num / den;
+  }
+
+  T r = sqrt(LN2 - log(1.0 - abs_x));
+
+  T num, den;
+  if (r <= 5.0) {
+    r = r - 1.6;
+    num =
+        (((((((C7 * r + C6) * r + C5) * r + C4) * r + C3) * r + C2) * r + C1) *
+             r +
+         C0);
+    den =
+        (((((((D7 * r + D6) * r + D5) * r + D4) * r + D3) * r + D2) * r + D1) *
+             r +
+         D0);
+  } else {
+    r = r - 5.0;
+    num =
+        (((((((E7 * r + E6) * r + E5) * r + E4) * r + E3) * r + E2) * r + E1) *
+             r +
+         E0);
+    den =
+        (((((((F7 * r + F6) * r + F5) * r + F4) * r + F3) * r + F2) * r + F1) *
+             r +
+         F0);
+  }
+
+  if (x < 0) {
+    return -num / den;
+  } else {
+    return num / den;
+  }
+}
+
+template <typename T>
+struct TruncatedNormal {
+  T mean, std;
+  T a_normal_cdf;
+  T b_normal_cdf;
+  TruncatedNormal(T mean, T std) : mean(mean), std(std) {
+    auto normal_cdf = [](T x) {
+      return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
+    };
+    a_normal_cdf = normal_cdf(-2.0);
+    b_normal_cdf = normal_cdf(2.0);
+  }
+
+  T operator()(T value) const {
+    auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
+    return std::sqrt(2.0) * Erfinv(2 * p - 1) * std + mean;
+  }
+};
+
 template <typename T, typename Context>
 void TruncatedGaussianRandomKernel(const Context& dev_ctx,
                                    const std::vector<int>& shape,
@@ -42,7 +173,13 @@ void TruncatedGaussianRandomKernel(const Context& dev_ctx,
   TruncatedNormal<T> truncated_normal(mean, std);
   int64_t size = tensor->numel();
 
-  auto engine = paddle::framework::GetCPURandomEngine(seed);
+  std::shared_ptr<std::mt19937_64> engine;
+  if (seed) {
+    engine = std::make_shared<std::mt19937_64>();
+    engine->seed(seed);
+  } else {
+    engine = dev_ctx.GetGenerator()->GetCPUEngine();
+  }
   for (int64_t i = 0; i < size; ++i) {
     data[i] = truncated_normal(dist(*engine));
   }
diff --git a/paddle/phi/kernels/gpu/truncated_gaussian_random_kernel.cu b/paddle/phi/kernels/gpu/truncated_gaussian_random_kernel.cu
index f27b32ca7b8319440b62f0d03d21129133c8470c..5b6ae9d09bff207fc56baf958fe15a5d4e9c52d2 100644
--- a/paddle/phi/kernels/gpu/truncated_gaussian_random_kernel.cu
+++ b/paddle/phi/kernels/gpu/truncated_gaussian_random_kernel.cu
@@ -24,8 +24,6 @@
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 
-#include "paddle/fluid/framework/generator.h"
-
 namespace phi {
 
 template <typename T>
@@ -106,8 +104,7 @@ void TruncatedGaussianRandomKernel(const Context& dev_ctx,
   thrust::counting_iterator<int64_t> index_sequence_begin(0);
   int64_t size = tensor->numel();
 
-  int device_id = dev_ctx.GetPlace().GetDeviceId();
-  auto gen_cuda = paddle::framework::GetDefaultCUDAGenerator(device_id);
+  auto gen_cuda = dev_ctx.GetGenerator();
 
   if (gen_cuda->GetIsInitPy() && seed_flag) {
     auto seed_offset = gen_cuda->IncrementOffset(1);
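The kernels above sample by inverse-CDF: a uniform draw is mapped into the CDF interval [Φ(-2), Φ(2)] and then inverted with `Erfinv`, exactly as `TruncatedNormal::operator()` does. A minimal NumPy/SciPy sketch of the same transform, with `scipy.special.erfinv` standing in for the polynomial approximation (sample size and seed are illustrative):

```python
import numpy as np
from math import erf, sqrt
from scipy.special import erfinv

def truncated_normal(u, mean=0.0, std=1.0):
    # Standard normal CDF at the fixed truncation bounds -2 and 2,
    # mirroring TruncatedNormal's a_normal_cdf / b_normal_cdf
    normal_cdf = lambda v: (1.0 + erf(v / sqrt(2.0))) / 2.0
    a_cdf, b_cdf = normal_cdf(-2.0), normal_cdf(2.0)
    # Map uniform u in [0, 1) into [Phi(-2), Phi(2)], then invert the CDF
    p = a_cdf + (b_cdf - a_cdf) * u
    return np.sqrt(2.0) * erfinv(2.0 * p - 1.0) * std + mean

rng = np.random.default_rng(10)
samples = truncated_normal(rng.uniform(0.0, 1.0, size=10000))
assert np.abs(samples).max() <= 2.0  # all mass inside [-2, 2]
print(samples.var())  # ~0.773, the value the unit test checks against
```
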
diff --git a/paddle/phi/kernels/truncated_gaussian_random_kernel.h b/paddle/phi/kernels/truncated_gaussian_random_kernel.h
index 2781b79520a5d05bf957a5139c720f6639da334f..773bfc8c71eacc3cf2707dfcde246cd5ae11c1ed 100644
--- a/paddle/phi/kernels/truncated_gaussian_random_kernel.h
+++ b/paddle/phi/kernels/truncated_gaussian_random_kernel.h
@@ -14,149 +14,11 @@
 
 #pragma once
 
-#include <limits>
-#include <random>
-
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/core/device_context.h"
-#include "paddle/phi/infermeta/nullary.h"
 
 namespace phi {
 
-// reference: https://gist.github.com/lakshayg/d80172fe5ae3c5d2c2aedb53c250320e
-template <typename T>
-T Erfinv(T x) {
-  if (x < -1 || x > 1) {
-    return std::numeric_limits<T>::quiet_NaN();
-  } else if (x == 1.0) {
-    return std::numeric_limits<T>::infinity();
-  } else if (x == -1.0) {
-    return -std::numeric_limits<T>::infinity();
-  }
-
-  const T LN2 = 6.931471805599453094172321214581e-1;
-
-  const T A0 = 1.1975323115670912564578e0;
-  const T A1 = 4.7072688112383978012285e1;
-  const T A2 = 6.9706266534389598238465e2;
-  const T A3 = 4.8548868893843886794648e3;
-  const T A4 = 1.6235862515167575384252e4;
-  const T A5 = 2.3782041382114385731252e4;
-  const T A6 = 1.1819493347062294404278e4;
-  const T A7 = 8.8709406962545514830200e2;
-
-  const T B0 = 1.0000000000000000000e0;
-  const T B1 = 4.2313330701600911252e1;
-  const T B2 = 6.8718700749205790830e2;
-  const T B3 = 5.3941960214247511077e3;
-  const T B4 = 2.1213794301586595867e4;
-  const T B5 = 3.9307895800092710610e4;
-  const T B6 = 2.8729085735721942674e4;
-  const T B7 = 5.2264952788528545610e3;
-
-  const T C0 = 1.42343711074968357734e0;
-  const T C1 = 4.63033784615654529590e0;
-  const T C2 = 5.76949722146069140550e0;
-  const T C3 = 3.64784832476320460504e0;
-  const T C4 = 1.27045825245236838258e0;
-  const T C5 = 2.41780725177450611770e-1;
-  const T C6 = 2.27238449892691845833e-2;
-  const T C7 = 7.74545014278341407640e-4;
-
-  const T D0 = 1.4142135623730950488016887e0;
-  const T D1 = 2.9036514445419946173133295e0;
-  const T D2 = 2.3707661626024532365971225e0;
-  const T D3 = 9.7547832001787427186894837e-1;
-  const T D4 = 2.0945065210512749128288442e-1;
-  const T D5 = 2.1494160384252876777097297e-2;
-  const T D6 = 7.7441459065157709165577218e-4;
-  const T D7 = 1.4859850019840355905497876e-9;
-
-  const T E0 = 6.65790464350110377720e0;
-  const T E1 = 5.46378491116411436990e0;
-  const T E2 = 1.78482653991729133580e0;
-  const T E3 = 2.96560571828504891230e-1;
-  const T E4 = 2.65321895265761230930e-2;
-  const T E5 = 1.24266094738807843860e-3;
-  const T E6 = 2.71155556874348757815e-5;
-  const T E7 = 2.01033439929228813265e-7;
-
-  const T F0 = 1.414213562373095048801689e0;
-  const T F1 = 8.482908416595164588112026e-1;
-  const T F2 = 1.936480946950659106176712e-1;
-  const T F3 = 2.103693768272068968719679e-2;
-  const T F4 = 1.112800997078859844711555e-3;
-  const T F5 = 2.611088405080593625138020e-5;
-  const T F6 = 2.010321207683943062279931e-7;
-  const T F7 = 2.891024605872965461538222e-15;
-
-  T abs_x = abs(x);
-
-  if (abs_x <= 0.85) {
-    T r = 0.180625 - 0.25 * x * x;
-    T num =
-        (((((((A7 * r + A6) * r + A5) * r + A4) * r + A3) * r + A2) * r + A1) *
-             r +
-         A0);
-    T den =
-        (((((((B7 * r + B6) * r + B5) * r + B4) * r + B3) * r + B2) * r + B1) *
-             r +
-         B0);
-    return x * num / den;
-  }
-
-  T r = sqrt(LN2 - log(1.0 - abs_x));
-
-  T num, den;
-  if (r <= 5.0) {
-    r = r - 1.6;
-    num =
-        (((((((C7 * r + C6) * r + C5) * r + C4) * r + C3) * r + C2) * r + C1) *
-             r +
-         C0);
-    den =
-        (((((((D7 * r + D6) * r + D5) * r + D4) * r + D3) * r + D2) * r + D1) *
-             r +
-         D0);
-  } else {
-    r = r - 5.0;
-    num =
-        (((((((E7 * r + E6) * r + E5) * r + E4) * r + E3) * r + E2) * r + E1) *
-             r +
-         E0);
-    den =
-        (((((((F7 * r + F6) * r + F5) * r + F4) * r + F3) * r + F2) * r + F1) *
-             r +
-         F0);
-  }
-
-  if (x < 0) {
-    return -num / den;
-  } else {
-    return num / den;
-  }
-}
-
-template <typename T>
-struct TruncatedNormal {
-  T mean, std;
-  T a_normal_cdf;
-  T b_normal_cdf;
-  TruncatedNormal(T mean, T std) : mean(mean), std(std) {
-    auto normal_cdf = [](T x) {
-      return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
-    };
-    a_normal_cdf = normal_cdf(-2.0);
-    b_normal_cdf = normal_cdf(2.0);
-  }
-
-  T operator()(T value) const {
-    auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
-    return std::sqrt(2.0) * Erfinv(2 * p - 1) * std + mean;
-  }
-};
-
 template <typename T, typename Context>
 void TruncatedGaussianRandomKernel(const Context& dev_ctx,
                                    const std::vector<int>& shape,
diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py
index bdc97eca0d84f0f5d67aa23b1fae749ba0179818..37eff6d132d03bc634f9d0ae3fdb62d118d2820e 100644
--- a/python/paddle/fluid/initializer.py
+++ b/python/paddle/fluid/initializer.py
@@ -17,7 +17,7 @@ from __future__ import print_function
 import math
 from . import framework
 from . import core
-from .framework import _non_static_mode, default_main_program
+from .framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph, default_main_program, _current_expected_place
 import numpy as np
 from .core import VarDesc
 from . import unique_name
@@ -417,7 +417,18 @@ class TruncatedNormalInitializer(Initializer):
             out_dtype = var.dtype
             out_var = var
 
-        if framework._non_static_mode():
+        if in_dygraph_mode():
+            out_var = _C_ops.final_state_truncated_gaussian_random(
+                var.shape, self._mean, self._std_dev, self._seed, out_dtype,
+                _current_expected_place())
+            if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
+                var_tmp = _C_ops.final_state_cast(out_var, var.dtype)
+                var_tmp._share_underline_tensor_to(var)
+            else:
+                out_var._share_underline_tensor_to(var)
+            return None
+
+        if _in_legacy_dygraph():
             out_var = _C_ops.truncated_gaussian_random(
                 'shape', var.shape, 'dtype', out_dtype, 'mean', self._mean,
                 'std', self._std_dev, 'seed', self._seed)
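With the `in_dygraph_mode()` branch above, the initializer reaches the new phi kernel through `final_state_truncated_gaussian_random`. A small sketch of the user-facing path (layer sizes and seed are arbitrary):

```python
import paddle

paddle.seed(2022)
init = paddle.nn.initializer.TruncatedNormal(mean=0.0, std=1.0)
linear = paddle.nn.Linear(4, 8,
                          weight_attr=paddle.ParamAttr(initializer=init))

w = linear.weight.numpy()
# Samples come from N(0, 1) truncated to [-2, 2], so no value can be
# more than two standard deviations from the mean
assert abs(w).max() <= 2.0
```
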
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py
index 4e4fe69d914fadd394228740fd4866610e71b6a0..44263b89e161681a9043eb9d454ebf485a0122cf 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py
@@ -113,6 +113,7 @@ class TestMKLDNNSwishDim2(TestSwish):
         super(TestMKLDNNSwishDim2, self).setUp()
 
         self.attrs["use_mkldnn"] = True
+        self.check_eager = False
 
     def init_dtype(self):
         self.dtype = np.float32
@@ -284,6 +285,7 @@ class TestMKLDNNSwishDim4(TestSwish):
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}
         self.attrs = {"use_mkldnn": True, "beta": beta}
+        self.check_eager = False
 
     def init_dtype(self):
         self.dtype = np.float32
diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py
index 04e37a9b0379aa35743242721092d26ad2f334ef..c14bb4586f612d2201d784594c288e5d3f8cd555 100755
--- a/python/paddle/fluid/tests/unittests/test_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -25,6 +25,7 @@ import paddle.nn.functional as F
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 from paddle.fluid import compiler, Program, program_guard
+from paddle.fluid.framework import _test_eager_guard
 
 paddle.enable_static()
 
@@ -2928,7 +2929,9 @@ def ref_swish(x):
 class TestSwish(TestActivation):
     def setUp(self):
         self.op_type = "swish"
+        self.python_api = paddle.nn.functional.swish
         self.init_dtype()
+        self.check_eager = True
 
         np.random.seed(1024)
         x = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype)
@@ -2940,7 +2943,10 @@ class TestSwish(TestActivation):
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
-        self.check_grad(['X'], 'Out')
+        check_eager = False
+        if hasattr(self, 'check_eager'):
+            check_eager = self.check_eager
+        self.check_grad(['X'], 'Out', check_eager=check_eager)
 
 
 class TestSwishAPI(unittest.TestCase):
@@ -2975,6 +2981,10 @@ class TestSwishAPI(unittest.TestCase):
         self.assertEqual(np.allclose(out_ref, r.numpy()), True)
         paddle.enable_static()
 
+    def test_dygraph_final_state_api(self):
+        with _test_eager_guard():
+            self.test_dygraph_api()
+
     def test_fluid_api(self):
         paddle.enable_static()
         with fluid.program_guard(fluid.Program()):
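The eager branch being tested computes the same function as the legacy path: swish with the fixed `beta = 1.0` is `x * sigmoid(x)`. A quick numerical cross-check (input values taken from the docstring example):

```python
import numpy as np
import paddle
import paddle.nn.functional as F

x_np = np.array([-2.0, 0.0, 1.0], dtype="float32")
out = F.swish(paddle.to_tensor(x_np))  # final_state_swish in eager mode

ref = x_np / (1.0 + np.exp(-x_np))     # x * sigmoid(x), i.e. beta = 1.0
assert np.allclose(out.numpy(), ref, atol=1e-6)
# matches the documented output: [-0.238406, 0., 0.731059]
```
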
diff --git a/python/paddle/fluid/tests/unittests/test_truncated_gaussian_random_op.py b/python/paddle/fluid/tests/unittests/test_truncated_gaussian_random_op.py
index 4abeae77d26e8def85596aefc6c2f89cd4e4d6f0..fe28e0c9638b4bf94c48f2f2150087eb8ab26590 100644
--- a/python/paddle/fluid/tests/unittests/test_truncated_gaussian_random_op.py
+++ b/python/paddle/fluid/tests/unittests/test_truncated_gaussian_random_op.py
@@ -17,10 +17,13 @@ from __future__ import print_function
 import unittest
 import numpy
 
+import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
+from op_test import OpTest
 from paddle.fluid.op import Operator
 from paddle.fluid.executor import Executor
+from paddle.fluid.framework import _test_eager_guard
 
 
 class TestTrunctedGaussianRandomOp(unittest.TestCase):
@@ -33,15 +36,16 @@ class TestTrunctedGaussianRandomOp(unittest.TestCase):
             "std": 1.,
             "seed": 10,
         }
-
         self.outputs = ["Out"]
 
     def test_cpu(self):
         self.gaussian_random_test(place=fluid.CPUPlace())
+        self.gaussian_random_test_eager(place=fluid.CPUPlace())
 
     def test_gpu(self):
         if core.is_compiled_with_cuda():
             self.gaussian_random_test(place=fluid.CUDAPlace(0))
+            self.gaussian_random_test_eager(place=fluid.CUDAPlace(0))
 
     def gaussian_random_test(self, place):
 
@@ -64,6 +68,17 @@ class TestTrunctedGaussianRandomOp(unittest.TestCase):
         self.assertAlmostEqual(numpy.mean(tensor), .0, delta=0.1)
         self.assertAlmostEqual(numpy.var(tensor), 0.773, delta=0.1)
 
+    # TruncatedNormal.__call__ has no return value, so here we call the
+    # _C_ops api directly
+    def gaussian_random_test_eager(self, place):
+        with fluid.dygraph.guard(place):
+            with _test_eager_guard():
+                out = paddle._C_ops.final_state_truncated_gaussian_random(
+                    self.attrs["shape"], self.attrs["mean"], self.attrs["std"],
+                    self.attrs["seed"], core.VarDesc.VarType.FP32, place)
+                self.assertAlmostEqual(numpy.mean(out.numpy()), .0, delta=0.1)
+                self.assertAlmostEqual(numpy.var(out.numpy()), 0.773, delta=0.1)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_unbind_op.py b/python/paddle/fluid/tests/unittests/test_unbind_op.py
index e16fb6ddaacd71f1f76a72bd301fe06ce8059214..43f2f3526ac0fcc87c3e483f63534ac5e052d249 100644
--- a/python/paddle/fluid/tests/unittests/test_unbind_op.py
+++ b/python/paddle/fluid/tests/unittests/test_unbind_op.py
@@ -17,9 +17,11 @@ from __future__ import print_function
 import unittest
 import numpy as np
 from op_test import OpTest, convert_float_to_uint16
+import paddle
 import paddle.fluid as fluid
 import paddle.tensor as tensor
 from paddle.fluid import compiler, Program, program_guard, core
+from paddle.fluid.framework import _test_eager_guard
 
 
 class TestUnbind(unittest.TestCase):
@@ -39,6 +41,25 @@ class TestUnbind(unittest.TestCase):
         assert np.array_equal(res_1, input_1[0, 0:100])
         assert np.array_equal(res_2, input_1[1, 0:100])
 
+    def test_unbind_dygraph(self):
+        with fluid.dygraph.guard():
+            np_x = np.random.random([2, 3]).astype("float32")
+            x = paddle.to_tensor(np_x)
+            x.stop_gradient = False
+            [res_1, res_2] = paddle.unbind(x, 0)
+            self.assertTrue(np.array_equal(res_1, np_x[0, 0:100]))
+            self.assertTrue(np.array_equal(res_2, np_x[1, 0:100]))
+
+            out = paddle.add_n([res_1, res_2])
+
+            np_grad = np.ones(x.shape, np.float32)
+            out.backward()
+            self.assertTrue(np.array_equal(x.grad.numpy(), np_grad))
+
+    def test_unbind_dygraph_final_state(self):
+        with _test_eager_guard():
+            self.test_unbind_dygraph()
+
 
 class TestLayersUnbind(unittest.TestCase):
     def test_layers_unbind(self):
@@ -157,6 +178,7 @@ class TestUnbindOp4(TestUnbindOp):
 class TestUnbindBF16Op(OpTest):
     def setUp(self):
         self._set_op_type()
+        self.python_api = paddle.unbind
         self.dtype = self.get_dtype()
         self.axis = 0
         self.num = 3
diff --git a/python/paddle/fluid/tests/unittests/test_unique.py b/python/paddle/fluid/tests/unittests/test_unique.py
index a4bef436e13755acb14b7eb8226b19774453d528..71dce5cc463cf5c23dc2401911ec7cc03f1c8d59 100644
--- a/python/paddle/fluid/tests/unittests/test_unique.py
+++ b/python/paddle/fluid/tests/unittests/test_unique.py
@@ -21,6 +21,7 @@ import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 from paddle.fluid.op import Operator
+from paddle.fluid.framework import _test_eager_guard
 
 
 class TestUniqueOp(OpTest):
@@ -251,6 +252,12 @@ class TestUniqueAPI(unittest.TestCase):
         self.assertTrue((counts.numpy() == np_counts).all(), True)
         paddle.enable_static()
 
+    def test_dygraph_final_state_api(self):
+        with _test_eager_guard():
+            self.test_dygraph_api_out()
+            self.test_dygraph_api_attr()
+            self.test_dygraph_attr_dtype()
+
     def test_static_graph(self):
         with paddle.static.program_guard(paddle.static.Program(),
                                          paddle.static.Program()):
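For orientation: the final-state op returns `(out, indices, inverse, counts)` in yaml order, which `paddle.unique` then filters by the `return_*` flags, matching the reordered legacy `(out, inverse, indices, counts)` result. A usage sketch (values from the `paddle.unique` docstring):

```python
import numpy as np
import paddle

x = paddle.to_tensor(np.array([2, 3, 3, 1, 5, 3]))
out, indices, inverse, counts = paddle.unique(
    x, return_index=True, return_inverse=True, return_counts=True)

print(out.numpy())      # [1 2 3 5]
print(indices.numpy())  # first occurrence of each unique value: [3 0 1 4]
print(inverse.numpy())  # index of each element of x in out: [1 2 2 0 3 2]
print(counts.numpy())   # [1 1 3 1]
```
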
diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index 3bdda982ff4f1f17d025f4fcfafd89fd839e974d..ce82b10701b3c8f52b332d578c3a5b585eb75cc9 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -1175,8 +1175,9 @@ def swish(x, name=None):
             x = paddle.to_tensor(np.array([-2., 0., 1.]))
             out = F.swish(x) # [-0.238406, 0., 0.731059]
     """
-
-    if in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_swish(x, 1.0)
+    if _in_legacy_dygraph():
         return _C_ops.swish(x, 'beta', 1.0)
     check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                              'swish')
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index f1e2938b205c702ae1a420bde37086419ed3f33d..0f90cf6950aff7c300b57b1080eaa3d8419d8ee0 100755
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -1211,11 +1211,16 @@ def unique(x,
     else:
         axis = [axis]
     attr_dtype = convert_np_dtype_to_dtype_(dtype)
-    if paddle.in_dynamic_mode():
-        out, inverse, indices, counts = _C_ops.unique(
-            x, 'dtype', attr_dtype, 'return_index', return_index,
-            'return_inverse', return_inverse, 'return_counts', return_counts,
-            'axis', axis, "is_sorted", True)
+    if _non_static_mode():
+        if in_dygraph_mode():
+            out, indices, inverse, counts = _C_ops.final_state_unique(
+                x, return_index, return_inverse, return_counts, axis,
+                attr_dtype)
+        if _in_legacy_dygraph():
+            out, inverse, indices, counts = _C_ops.unique(
+                x, 'dtype', attr_dtype, 'return_index', return_index,
+                'return_inverse', return_inverse, 'return_counts',
+                return_counts, 'axis', axis, "is_sorted", True)
         outs = [out]
         if return_index:
             outs.append(indices)
@@ -1464,6 +1469,9 @@ def unbind(input, axis=0):
             # x3.shape [3, 5]
 
     """
+    if in_dygraph_mode():
+        return _C_ops.final_state_unbind(input, axis)
+
     if not isinstance(axis, (int)):
         raise TypeError("The type of 'axis' must be int, but received %s." %
                         (type(axis)))
@@ -1472,7 +1480,7 @@ def unbind(input, axis=0):
     input_shape = input.shape
     axis_ = axis if axis >= 0 else len(input_shape) + axis
     num = input_shape[axis_]
-    if paddle.in_dynamic_mode():
+    if _in_legacy_dygraph():
         return _C_ops.unbind(input, num, 'axis', axis)
 
     helper = LayerHelper("unbind", **locals())
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index acaab007e03fb6a74857bb3305f091dd1146a66a..fa5706836d80504b1eb7ca792f22b9f52b3c2119 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -1744,6 +1744,17 @@
     data_type : x
   backward : sum_grad
 
+# The python API paddle.nn.functional.swish has no `beta` argument, so it may be removed later
+- api : swish
+  args : (Tensor x, float beta=1.0)
+  output : Tensor(out)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : swish
+  backward : swish_grad
+
 # take_along_axis
 - api : take_along_axis
   args : (Tensor x, Tensor index, int axis)
@@ -1861,6 +1872,25 @@
     func : trunc
   backward : trunc_grad
 
+# python API: paddle.nn.initializer.TruncatedNormal
+- api : truncated_gaussian_random
+  args : (int[] shape, float mean, float std, int seed, DataType dtype=DataType::FLOAT32, Place place={})
+  output : Tensor
+  infer_meta :
+    func : TruncatedGaussianRandomInferMeta
+    param : [shape, mean, std, seed, dtype]
+  kernel :
+    func : truncated_gaussian_random
+    param : [shape, mean, std, seed, dtype]
+    backend : place
+    data_type : dtype
+
+- api : unbind
+  args : (Tensor input, int axis)
+  output : Tensor[]
+  invoke : unbind_impl(input, axis)
+  backward : unbind_grad
+
 # unfold
 - api : unfold
   args : (Tensor x, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations)
@@ -1871,6 +1901,16 @@
     func : unfold
   backward : unfold_grad
 
+# The `axis` argument of the Python API paddle.unique is not a vector
+- api : unique
+  args : (Tensor x, bool return_index, bool return_inverse, bool return_counts, int[] axis, DataType dtype=DataType::INT64)
+  output : Tensor(out), Tensor(indices), Tensor(inverse), Tensor(counts)
+  infer_meta :
+    func : UniqueInferMeta
+  kernel :
+    func : unique
+    data_type : x
+
 - api : unsqueeze
   args : (Tensor x, IntArray axes)
   output : Tensor(xshape), Tensor(out)
diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml
index 4cb411634a0adc4ee872eb6959ca6999205c769b..708c754f78269ac3ad8fe591ecf44438cb61ec08 100644
--- a/python/paddle/utils/code_gen/backward.yaml
+++ b/python/paddle/utils/code_gen/backward.yaml
@@ -1317,6 +1317,16 @@
   kernel :
     func : sum_grad
 
+- backward_api : swish_grad
+  forward : swish (Tensor x, float beta=1.0) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad, float beta=1.0)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : GeneralUnaryGradInferMeta
+    param : [x]
+  kernel :
+    func : swish_grad
+
 - backward_api : take_along_axis_grad
   forward : take_along_axis (Tensor x, Tensor index, int axis) -> Tensor(out)
   args : (Tensor x, Tensor index, Tensor out_grad, int axis)
@@ -1429,6 +1439,12 @@
   kernel :
     func : trunc_grad
 
+- backward_api : unbind_grad
+  forward : unbind (Tensor input, int axis) -> Tensor[](out)
+  args : (Tensor[] out_grad, int axis)
+  output : Tensor(input_grad)
+  invoke : stack(out_grad, axis)
+
 - backward_api : unfold_grad
   forward : unfold (Tensor x, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations) -> Tensor(out)
   args : (Tensor x, Tensor out_grad, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations)
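Since `unbind_grad` is expressed as `invoke : stack(out_grad, axis)`, the backward pass can be sanity-checked from Python: stacking the per-output gradients along the unbound axis must reproduce the input gradient. A minimal sketch (shapes and the loss weighting are illustrative):

```python
import numpy as np
import paddle

x = paddle.rand([2, 3])
x.stop_gradient = False

a, b = paddle.unbind(x, axis=0)
(a.sum() + 2.0 * b.sum()).backward()

# d/dx stacks the two output grads: stack([ones(3), 2*ones(3)], axis=0)
expected = np.stack([np.ones(3), 2.0 * np.ones(3)]).astype("float32")
assert np.allclose(x.grad.numpy(), expected)
```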