未验证 提交 14e45f6b 编写于 作者: H HongyuJia 提交者: GitHub

[Tensor Operator] Overload Tensor Operator (#50098)

* init commit

* fix tensor operator*

* fix compile bug

* bug reproduce

* update commit

* polish codes

* fix compile bug

* test begin

* test begin

* compile finish

* restore origin composite_backward_api

* pass local CI

* fix merge error

* fix merge error

* change py_test from GPU->CPU, test custom op

* polish codes, modify prim unittest

* modify prim unittest

* determine phi_tensor_operants location

* polish codes

* add header file

* solve windows unresolved symbol

* fix some CI error

* add overload defination

* fix CI inference and Windows

* polish codes according to reviewers' opinion

* polish codes according to reviewers' opinion
上级 fd0d4fa4
...@@ -191,6 +191,7 @@ paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallV ...@@ -191,6 +191,7 @@ paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallV
FORWARD_FUNCTION_TEMPLATE = """ FORWARD_FUNCTION_TEMPLATE = """
{} {}({}) {{ {} {}({}) {{
FLAGS_tensor_operants_mode = "eager";
VLOG(3) << \"Running AD API: \" << \"{}\"; VLOG(3) << \"Running AD API: \" << \"{}\";
// Dygraph Record Event // Dygraph Record Event
{} {}
...@@ -246,6 +247,7 @@ BEFORE_LOG_PRINT_TEMPLATE = """ ...@@ -246,6 +247,7 @@ BEFORE_LOG_PRINT_TEMPLATE = """
FORWARD_ONLY_FUNCTION_TEMPLATE = """ FORWARD_ONLY_FUNCTION_TEMPLATE = """
{} {}({}) {{ {} {}({}) {{
FLAGS_tensor_operants_mode = "eager";
VLOG(3) << \"Running AD API: \" << \"{}\"; VLOG(3) << \"Running AD API: \" << \"{}\";
// Dygraph Record Event // Dygraph Record Event
{} {}
...@@ -364,6 +366,7 @@ FORWARD_CC_FILE_TEMPLATE = """ ...@@ -364,6 +366,7 @@ FORWARD_CC_FILE_TEMPLATE = """
#include "paddle/fluid/eager/nan_inf_utils.h" #include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h" #include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
DECLARE_bool(check_nan_inf); DECLARE_bool(check_nan_inf);
DECLARE_string(tensor_operants_mode);
{} {}
{} {}
""" """
......
...@@ -1203,7 +1203,9 @@ cc_library( ...@@ -1203,7 +1203,9 @@ cc_library(
string_helper string_helper
phi_tensor phi_tensor
op_meta_info op_meta_info
phi_api) phi_api
phi_tensor_operants
operants_manager)
set(FLUID_FRAMEWORK_MODULES set(FLUID_FRAMEWORK_MODULES
proto_desc proto_desc
......
...@@ -45,6 +45,12 @@ limitations under the License. */ ...@@ -45,6 +45,12 @@ limitations under the License. */
#include "paddle/phi/backends/device_manager.h" #include "paddle/phi/backends/device_manager.h"
#endif #endif
#include "gflags/gflags.h"
#include "paddle/phi/api/include/tensor_operants.h"
#include "paddle/phi/core/operants_manager.h"
DECLARE_string(tensor_operants_mode);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -270,6 +276,15 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx, ...@@ -270,6 +276,15 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
try { try {
VLOG(3) << "Custom Operator: Run ComputeFunc."; VLOG(3) << "Custom Operator: Run ComputeFunc.";
FLAGS_tensor_operants_mode = "phi";
if (paddle::operants::OperantsManager::Instance().phi_operants.get() ==
nullptr) {
paddle::operants::OperantsManager::Instance().phi_operants.reset(
new paddle::operants::PhiTensorOperants());
VLOG(4) << "Initialize phi tensor operants successfully";
}
func(&kernel_ctx); func(&kernel_ctx);
// sync output tensor data into original output // sync output tensor data into original output
......
...@@ -29,7 +29,7 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) { ...@@ -29,7 +29,7 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
if (!grad_x) return; if (!grad_x) return;
auto tmp = pow<T>(out, 2.0); auto tmp = pow<T>(out, 2.0);
tmp = scale<T>(tmp, -1.0, 1.0, true); tmp = scale<T>(tmp, -1.0, 1.0, true);
auto grad_x_tmp = multiply<T>(grad_out, tmp); auto grad_x_tmp = grad_out * tmp;
set_output<T>(grad_x_tmp, grad_x); set_output<T>(grad_x_tmp, grad_x);
} }
...@@ -172,7 +172,7 @@ void divide_grad(const Tensor& x, ...@@ -172,7 +172,7 @@ void divide_grad(const Tensor& x,
auto tmp0 = pow<T>(y, 2.0); auto tmp0 = pow<T>(y, 2.0);
auto tmp1 = divide<T>(x, tmp0); auto tmp1 = divide<T>(x, tmp0);
auto tmp2 = scale<T>(tmp1, -1.0, 0.0, true); auto tmp2 = scale<T>(tmp1, -1.0, 0.0, true);
auto dy_res = multiply<T>(tmp2, out_grad); auto dy_res = tmp2 * out_grad;
if (x.dims() != y.dims()) { if (x.dims() != y.dims()) {
// Maybe need reduce here // Maybe need reduce here
phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
...@@ -192,7 +192,7 @@ void divide_grad(const Tensor& x, ...@@ -192,7 +192,7 @@ void divide_grad(const Tensor& x,
// dx = (1/y) * dout // dx = (1/y) * dout
auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0, y.dtype()); auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0, y.dtype());
auto tmp0 = divide<T>(one_tensor, y); auto tmp0 = divide<T>(one_tensor, y);
auto dx_res = multiply<T>(tmp0, out_grad); auto dx_res = tmp0 * out_grad;
if (y.dims() != x.dims()) { if (y.dims() != x.dims()) {
// Maybe need reduce here // Maybe need reduce here
auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
...@@ -216,7 +216,7 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { ...@@ -216,7 +216,7 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) { if (x_grad) {
auto div_x = full<T>(phi::vectorize(out.dims()), 0.5); auto div_x = full<T>(phi::vectorize(out.dims()), 0.5);
auto tmp = divide<T>(div_x, out); auto tmp = divide<T>(div_x, out);
auto x_grad_tmp = multiply<T>(out_grad, tmp); auto x_grad_tmp = out_grad * tmp;
set_output<T>(x_grad_tmp, x_grad); set_output<T>(x_grad_tmp, x_grad);
} }
} }
...@@ -229,7 +229,7 @@ void multiply_grad(const Tensor& x, ...@@ -229,7 +229,7 @@ void multiply_grad(const Tensor& x,
Tensor* x_grad, Tensor* x_grad,
Tensor* y_grad) { Tensor* y_grad) {
if (x_grad) { if (x_grad) {
auto x_grad_unreduce = multiply<T>(out_grad, y); auto x_grad_unreduce = out_grad * y;
if (x_grad_unreduce.dims() != x.dims()) { if (x_grad_unreduce.dims() != x.dims()) {
auto axes = get_reduce_dims_from_out(x_grad_unreduce.dims(), x.dims()); auto axes = get_reduce_dims_from_out(x_grad_unreduce.dims(), x.dims());
if (!axes.size()) { if (!axes.size()) {
...@@ -249,7 +249,7 @@ void multiply_grad(const Tensor& x, ...@@ -249,7 +249,7 @@ void multiply_grad(const Tensor& x,
} }
} }
if (y_grad) { if (y_grad) {
auto y_grad_unreduce = multiply<T>(out_grad, x); auto y_grad_unreduce = out_grad * x;
if (y_grad_unreduce.dims() != y.dims()) { if (y_grad_unreduce.dims() != y.dims()) {
auto axes = get_reduce_dims_from_out(y_grad_unreduce.dims(), y.dims()); auto axes = get_reduce_dims_from_out(y_grad_unreduce.dims(), y.dims());
if (!axes.size()) { if (!axes.size()) {
...@@ -297,7 +297,7 @@ void expand_grad(const Tensor& x, ...@@ -297,7 +297,7 @@ void expand_grad(const Tensor& x,
template <typename T> template <typename T>
void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) { if (x_grad) {
set_output<T>(multiply<T>(out_grad, out), x_grad); set_output<T>(out_grad * out, x_grad);
} }
} }
......
...@@ -33,9 +33,16 @@ cc_test_old( ...@@ -33,9 +33,16 @@ cc_test_old(
activation_op activation_op
phi_api phi_api
phi_dygraph_api phi_dygraph_api
static_global_utils) static_global_utils
static_tensor_operants
operants_manager)
if(NOT (NOT WITH_PYTHON AND ON_INFER)) if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_library(
init_env_utils
SRCS init_env_utils.cc
DEPS operants_manager eager_tensor_operants static_tensor_operants)
cc_test_old( cc_test_old(
test_comp_eager test_comp_eager
SRCS SRCS
...@@ -44,5 +51,6 @@ if(NOT (NOT WITH_PYTHON AND ON_INFER)) ...@@ -44,5 +51,6 @@ if(NOT (NOT WITH_PYTHON AND ON_INFER))
${prim_eager_deps} ${prim_eager_deps}
${prim_generated_deps} ${prim_generated_deps}
prim_utils prim_utils
static_global_utils) static_global_utils
init_env_utils)
endif() endif()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/prim/tests/init_env_utils.h"
#include "paddle/fluid/prim/utils/eager/eager_tensor_operants.h"
#include "paddle/fluid/prim/utils/static/static_tensor_operants.h"
#include "paddle/phi/core/operants_manager.h"
namespace paddle {
namespace prim {
void InitTensorOperants() {
paddle::operants::OperantsManager::Instance().eager_operants.reset(
new paddle::operants::EagerTensorOperants());
paddle::operants::OperantsManager::Instance().static_operants.reset(
new paddle::operants::StaticTensorOperants());
}
} // namespace prim
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace prim {
void InitTensorOperants();
} // namespace prim
} // namespace paddle
...@@ -14,17 +14,21 @@ ...@@ -14,17 +14,21 @@
#include <sstream> #include <sstream>
#include "gflags/gflags.h"
#include "glog/logging.h" #include "glog/logging.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/api/utils/hook_utils.h" #include "paddle/fluid/eager/api/utils/hook_utils.h"
#include "paddle/fluid/eager/backward.h" #include "paddle/fluid/eager/backward.h"
#include "paddle/fluid/eager/tests/test_utils.h" #include "paddle/fluid/eager/tests/test_utils.h"
#include "paddle/fluid/prim/tests/init_env_utils.h"
#include "paddle/fluid/prim/utils/utils.h" #include "paddle/fluid/prim/utils/utils.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/core/tensor_meta.h"
DECLARE_string(tensor_operants_mode);
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(tanh, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT);
...@@ -46,6 +50,8 @@ namespace prim { ...@@ -46,6 +50,8 @@ namespace prim {
TEST(EagerPrim, TanhBackwardTest) { TEST(EagerPrim, TanhBackwardTest) {
// 1. Initialized // 1. Initialized
eager_test::InitEnv(paddle::platform::CPUPlace()); eager_test::InitEnv(paddle::platform::CPUPlace());
FLAGS_tensor_operants_mode = "eager";
paddle::prim::InitTensorOperants();
// 2. pre // 2. pre
paddle::framework::DDim ddim = phi::make_ddim({4, 16, 16, 32}); paddle::framework::DDim ddim = phi::make_ddim({4, 16, 16, 32});
paddle::experimental::Tensor tensor0 = paddle::experimental::Tensor tensor0 =
......
...@@ -21,11 +21,15 @@ ...@@ -21,11 +21,15 @@
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/prim/api/manual_prim/utils/utils.h" #include "paddle/fluid/prim/api/manual_prim/utils/utils.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/fluid/prim/utils/static/static_tensor_operants.h"
#include "paddle/fluid/prim/utils/utils.h" #include "paddle/fluid/prim/utils/utils.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/operants_manager.h"
DECLARE_bool(prim_enabled); DECLARE_bool(prim_enabled);
DECLARE_string(tensor_operants_mode);
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(tanh, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT);
...@@ -142,6 +146,11 @@ class TestCompositeGradMaker : public CompositeGradOpMakerBase { ...@@ -142,6 +146,11 @@ class TestCompositeGradMaker : public CompositeGradOpMakerBase {
}; };
TEST(StaticPrim, TanhBackwardComposite) { TEST(StaticPrim, TanhBackwardComposite) {
// Initialized environment
FLAGS_tensor_operants_mode = "static";
paddle::operants::OperantsManager::Instance().static_operants.reset(
new paddle::operants::StaticTensorOperants());
TestBaseProgram base_program = TestBaseProgram(); TestBaseProgram base_program = TestBaseProgram();
auto* target_block = base_program.GetBlock(0); auto* target_block = base_program.GetBlock(0);
// Prepare for forward tanh // Prepare for forward tanh
...@@ -223,6 +232,11 @@ TEST(StaticPrim, TanhBackwardComposite) { ...@@ -223,6 +232,11 @@ TEST(StaticPrim, TanhBackwardComposite) {
} }
TEST(StaticCompositeGradMaker, TestMutiInputMethod) { TEST(StaticCompositeGradMaker, TestMutiInputMethod) {
// Initialized environment
FLAGS_tensor_operants_mode = "static";
paddle::operants::OperantsManager::Instance().static_operants.reset(
new paddle::operants::StaticTensorOperants());
TestBaseProgram base_program = TestBaseProgram(); TestBaseProgram base_program = TestBaseProgram();
auto* target_block = base_program.GetBlock(0); auto* target_block = base_program.GetBlock(0);
std::vector<int64_t> shape = {2, 2}; std::vector<int64_t> shape = {2, 2};
...@@ -285,6 +299,11 @@ TEST(StaticCompositeGradMaker, TestMutiInputMethod) { ...@@ -285,6 +299,11 @@ TEST(StaticCompositeGradMaker, TestMutiInputMethod) {
} }
TEST(StaticCompositeGradMaker, TestMutiOutputMethod) { TEST(StaticCompositeGradMaker, TestMutiOutputMethod) {
// Initialized environment
FLAGS_tensor_operants_mode = "static";
paddle::operants::OperantsManager::Instance().static_operants.reset(
new paddle::operants::StaticTensorOperants());
TestBaseProgram base_program = TestBaseProgram(); TestBaseProgram base_program = TestBaseProgram();
auto* target_block = base_program.GetBlock(0); auto* target_block = base_program.GetBlock(0);
std::vector<int64_t> shape = {4, 2}; std::vector<int64_t> shape = {4, 2};
......
if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_library(
eager_tensor_operants
SRCS eager_tensor_operants.cc
DEPS final_dygraph_function)
endif()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/prim/utils/eager/eager_tensor_operants.h"
#include "glog/logging.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
namespace paddle {
namespace operants {
Tensor EagerTensorOperants::multiply(const Tensor& x, const Tensor& y) {
return ::multiply_ad_func(x, y);
}
} // namespace operants
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/core/operants_base.h"
namespace paddle {
namespace operants {
class EagerTensorOperants : public TensorOperantsBase {
public:
EagerTensorOperants() = default;
Tensor multiply(const Tensor& x, const Tensor& y) override;
private:
DISABLE_COPY_AND_ASSIGN(EagerTensorOperants);
};
} // namespace operants
} // namespace paddle
...@@ -2,3 +2,8 @@ cc_library( ...@@ -2,3 +2,8 @@ cc_library(
static_global_utils static_global_utils
SRCS static_global_utils.cc SRCS static_global_utils.cc
DEPS proto_desc) DEPS proto_desc)
cc_library(
static_tensor_operants
SRCS static_tensor_operants.cc
DEPS static_prim_api)
...@@ -29,6 +29,10 @@ ...@@ -29,6 +29,10 @@
#include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/fluid/prim/utils/static/static_global_utils.h" #include "paddle/fluid/prim/utils/static/static_global_utils.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/flags.h"
DECLARE_string(tensor_operants_mode);
namespace paddle { namespace paddle {
namespace prim { namespace prim {
...@@ -59,6 +63,7 @@ class CompositeGradOpMakerBase { ...@@ -59,6 +63,7 @@ class CompositeGradOpMakerBase {
// TODO(jiabin): This should always execute by one thread... // TODO(jiabin): This should always execute by one thread...
VLOG(6) << "Constructing Composite Grad func for " << fwd_op_.Type() VLOG(6) << "Constructing Composite Grad func for " << fwd_op_.Type()
<< "_grad "; << "_grad ";
FLAGS_tensor_operants_mode = "static";
StaticCompositeContext::Instance().SetBlock( StaticCompositeContext::Instance().SetBlock(
acting_program_.MutableBlock(0)); acting_program_.MutableBlock(0));
} }
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/prim/utils/static/static_tensor_operants.h"
#include "glog/logging.h"
#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
namespace paddle {
namespace operants {
using DescTensor = paddle::prim::DescTensor;
Tensor StaticTensorOperants::multiply(const Tensor& x, const Tensor& y) {
return paddle::prim::multiply<DescTensor>(x, y);
}
} // namespace operants
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/core/operants_base.h"
namespace paddle {
namespace operants {
class StaticTensorOperants : public TensorOperantsBase {
public:
StaticTensorOperants() = default;
Tensor multiply(const Tensor& x, const Tensor& y) override;
private:
DISABLE_COPY_AND_ASSIGN(StaticTensorOperants);
};
} // namespace operants
} // namespace paddle
...@@ -497,6 +497,10 @@ if(WITH_PYTHON) ...@@ -497,6 +497,10 @@ if(WITH_PYTHON)
list(APPEND PYBIND_DEPS python) list(APPEND PYBIND_DEPS python)
list(APPEND PYBIND_DEPS custom_operator) list(APPEND PYBIND_DEPS custom_operator)
list(APPEND PYBIND_DEPS custom_operator_node) list(APPEND PYBIND_DEPS custom_operator_node)
list(APPEND PYBIND_DEPS operants_manager)
list(APPEND PYBIND_DEPS eager_tensor_operants)
list(APPEND PYBIND_DEPS static_tensor_operants)
list(APPEND PYBIND_DEPS phi_tensor_operants)
endif() endif()
# On Linux, cc_library(paddle SHARED ..) will generate the libpaddle.so, # On Linux, cc_library(paddle SHARED ..) will generate the libpaddle.so,
......
...@@ -35,6 +35,8 @@ typedef SSIZE_T ssize_t; ...@@ -35,6 +35,8 @@ typedef SSIZE_T ssize_t;
#include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/prim/utils/eager/eager_tensor_operants.h"
#include "paddle/fluid/prim/utils/static/static_tensor_operants.h"
#include "paddle/fluid/pybind/eager.h" #include "paddle/fluid/pybind/eager.h"
#include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/exception.h"
...@@ -54,6 +56,12 @@ typedef SSIZE_T ssize_t; ...@@ -54,6 +56,12 @@ typedef SSIZE_T ssize_t;
#include "paddle/fluid/pybind/cuda_streams_py.h" #include "paddle/fluid/pybind/cuda_streams_py.h"
#endif #endif
#include "gflags/gflags.h"
#include "paddle/phi/api/include/tensor_operants.h"
#include "paddle/phi/core/operants_manager.h"
DECLARE_string(tensor_operants_mode);
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
...@@ -487,10 +495,32 @@ static PyObject* eager_api_jit_function_call(PyObject* self, ...@@ -487,10 +495,32 @@ static PyObject* eager_api_jit_function_call(PyObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL EAGER_CATCH_AND_THROW_RETURN_NULL
} }
static PyObject* eager_api_init_eager_and_static_tensor_operants(
PyObject* self, PyObject* args, PyObject* kwargs) {
EAGER_TRY
paddle::operants::OperantsManager::Instance().eager_operants.reset(
new paddle::operants::EagerTensorOperants());
paddle::operants::OperantsManager::Instance().static_operants.reset(
new paddle::operants::StaticTensorOperants());
VLOG(4) << "Initialize eager and static tensor operants successfully";
RETURN_PY_NONE
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* eager_api_run_custom_op(PyObject* self, static PyObject* eager_api_run_custom_op(PyObject* self,
PyObject* args, PyObject* args,
PyObject* kwargs) { PyObject* kwargs) {
EAGER_TRY EAGER_TRY
FLAGS_tensor_operants_mode = "phi";
if (paddle::operants::OperantsManager::Instance().phi_operants.get() ==
nullptr) {
paddle::operants::OperantsManager::Instance().phi_operants.reset(
new paddle::operants::PhiTensorOperants());
VLOG(4) << "Initialize phi tensor operants successfully";
}
paddle::CustomOpKernelContext ctx = paddle::CustomOpKernelContext ctx =
CastPyArg2CustomOpKernelContext(PyTuple_GET_ITEM(args, 0), 0); CastPyArg2CustomOpKernelContext(PyTuple_GET_ITEM(args, 0), 0);
std::string op_type = CastPyArg2AttrString(PyTuple_GET_ITEM(args, 1), 1); std::string op_type = CastPyArg2AttrString(PyTuple_GET_ITEM(args, 1), 1);
...@@ -1090,6 +1120,11 @@ PyMethodDef variable_functions[] = { ...@@ -1090,6 +1120,11 @@ PyMethodDef variable_functions[] = {
(PyCFunction)(void (*)(void))eager_api_run_custom_op, (PyCFunction)(void (*)(void))eager_api_run_custom_op,
METH_VARARGS | METH_KEYWORDS, METH_VARARGS | METH_KEYWORDS,
NULL}, NULL},
{"_init_eager_and_static_tensor_operants",
(PyCFunction)(void (*)(
void))eager_api_init_eager_and_static_tensor_operants,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"tensor_copy", {"tensor_copy",
(PyCFunction)(void (*)(void))eager_api_tensor_copy, (PyCFunction)(void (*)(void))eager_api_tensor_copy,
METH_VARARGS | METH_KEYWORDS, METH_VARARGS | METH_KEYWORDS,
......
...@@ -615,5 +615,7 @@ class PADDLE_API Tensor final { ...@@ -615,5 +615,7 @@ class PADDLE_API Tensor final {
std::string name_{""}; std::string name_{""};
}; };
PADDLE_API Tensor operator*(const Tensor& x, const Tensor& y);
} // namespace experimental } // namespace experimental
} // namespace paddle } // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/core/operants_base.h"
namespace paddle {
namespace operants {
class PhiTensorOperants : public TensorOperantsBase {
public:
PhiTensorOperants() = default;
Tensor multiply(const Tensor& x, const Tensor& y) override;
private:
DISABLE_COPY_AND_ASSIGN(PhiTensorOperants);
};
} // namespace operants
} // namespace paddle
...@@ -4,17 +4,20 @@ if(WITH_GPU) ...@@ -4,17 +4,20 @@ if(WITH_GPU)
nv_library( nv_library(
phi_tensor_raw phi_tensor_raw
SRCS tensor.cc SRCS tensor.cc
DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool) DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool
operants_manager)
elseif(WITH_ROCM) elseif(WITH_ROCM)
hip_library( hip_library(
phi_tensor_raw phi_tensor_raw
SRCS tensor.cc SRCS tensor.cc
DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool) DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool
operants_manager)
else() else()
cc_library( cc_library(
phi_tensor_raw phi_tensor_raw
SRCS tensor.cc SRCS tensor.cc
DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool) DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool
operants_manager)
endif() endif()
set(api_gen_base ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/api_base.py) set(api_gen_base ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/api_base.py)
...@@ -308,3 +311,8 @@ cc_library( ...@@ -308,3 +311,8 @@ cc_library(
api_int_array api_int_array
SRCS int_array.cc SRCS int_array.cc
DEPS tensor_copy) DEPS tensor_copy)
cc_library(
phi_tensor_operants
SRCS tensor_operants.cc
DEPS phi_function_api)
...@@ -27,6 +27,7 @@ limitations under the License. */ ...@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/operants_manager.h"
#include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h"
...@@ -413,5 +414,9 @@ void Tensor::reset_inplace_version(bool set_to_zero) { ...@@ -413,5 +414,9 @@ void Tensor::reset_inplace_version(bool set_to_zero) {
} }
} }
PADDLE_API Tensor operator*(const Tensor &x, const Tensor &y) {
return paddle::operants::OperantsManager::Instance().multiply(x, y);
}
} // namespace experimental } // namespace experimental
} // namespace paddle } // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/api/include/tensor_operants.h"
#include "glog/logging.h"
#include "paddle/phi/api/include/api.h"
namespace paddle {
namespace operants {
Tensor PhiTensorOperants::multiply(const Tensor& x, const Tensor& y) {
return paddle::experimental::multiply(x, y);
}
} // namespace operants
} // namespace paddle
...@@ -114,6 +114,11 @@ cc_library( ...@@ -114,6 +114,11 @@ cc_library(
SRCS custom_kernel.cc SRCS custom_kernel.cc
DEPS kernel_factory) DEPS kernel_factory)
cc_library(
operants_manager
SRCS operants_manager.cc
DEPS flags)
cc_library( cc_library(
mixed_vector mixed_vector
SRCS mixed_vector.cc SRCS mixed_vector.cc
......
...@@ -1206,3 +1206,19 @@ PADDLE_DEFINE_EXPORTED_bool(trt_ibuilder_cache, ...@@ -1206,3 +1206,19 @@ PADDLE_DEFINE_EXPORTED_bool(trt_ibuilder_cache,
PADDLE_DEFINE_EXPORTED_bool(use_shm_cache, PADDLE_DEFINE_EXPORTED_bool(use_shm_cache,
false, false,
"Use shm cache in mmap_allocator."); "Use shm cache in mmap_allocator.");
/**
* Tensor operants related FLAG
* Name: tensor_operants_mode
* Since Version: 2.5.0
* Value Range: string, {eager, phi, static}
* default=eager
* Example:
* Note: For switching tensor operants mode of PaddlePaddle.
* - eager mode: tensor operants with dygraph autograd;
* - phi mode: tensor operants with only phi forward API;
* - static mode: tensor operants within static graph.
*/
PADDLE_DEFINE_EXPORTED_string(tensor_operants_mode,
"eager",
"Tensor operants mode");
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/api/include/tensor.h"
namespace paddle {
namespace operants {
using Tensor = paddle::experimental::Tensor;
class TensorOperantsBase {
public:
virtual ~TensorOperantsBase() = default;
virtual Tensor multiply(const Tensor& x, const Tensor& y) = 0;
};
} // namespace operants
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/operants_manager.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
DECLARE_string(tensor_operants_mode);
namespace paddle {
namespace operants {
OperantsManager& OperantsManager::Instance() {
static OperantsManager g_op_manager;
return g_op_manager;
}
Tensor OperantsManager::multiply(const Tensor& x, const Tensor& y) {
if (FLAGS_tensor_operants_mode == "eager") {
PADDLE_ENFORCE_NE(
this->eager_operants.get(),
nullptr,
phi::errors::Unavailable("The eager_operants pointer of "
"OperantsManager is not initialized"));
VLOG(4) << "OperantsManager reaches eager mode";
return this->eager_operants->multiply(x, y);
} else if (FLAGS_tensor_operants_mode == "static") {
PADDLE_ENFORCE_NE(
this->static_operants.get(),
nullptr,
phi::errors::Unavailable("The static_operants pointer of "
"OperantsManager is not initialized"));
VLOG(4) << "OperantsManager reaches static mode";
return this->static_operants->multiply(x, y);
} else if (FLAGS_tensor_operants_mode == "phi") {
PADDLE_ENFORCE_NE(
this->phi_operants.get(),
nullptr,
phi::errors::Unavailable(
"The phi_operants pointer of OperantsManager is not initialized"));
VLOG(4) << "OperantsManager reaches phi mode";
return this->phi_operants->multiply(x, y);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"FLAGS_tensor_operants_mode is not nitialized, please set "
"FLAGS_tensor_operants_mode first, which currently supports eager, "
"phi, and static mode"));
}
}
} // namespace operants
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/core/operants_base.h"
namespace paddle {
namespace operants {
using Tensor = paddle::experimental::Tensor;
/**
* [ Why need OperantsManager? ]
*
* Ideally, overloading tensor operators should call Tensor API directly.
* However, we faced two problems:
*
* 1. Support multiple modes: Tensor operator overloading needs to support
* [static mode / autograd mode / custom operator mode] at the same time.
*
* 2. Decouple phi and fluid: Tensor belongs to the phi library, but it relies
* upon functions in fluid when overloading Tensor operators.
*
* We design OperantsManager to solve these two problems:
*
* 1. use `FLAGS_tensor_operants_mode` to handle overloading mode, set this flag
* at the entry point of each mode:
*
* - FLAGS_tensor_operants_mode = "static": at the construction function of
* `CompositeGradOpMakerBase`.
* - FLAGS_tensor_operants_mode = "eager": at the beginning of dygraph_function.
* - FLAGS_tensor_operants_mode = "phi": at the beginning of the
* `eager_api_run_custom_op` function in eager mode and at the beginning of
* calling kernels in static mode.
*
* In order to guarantee the performance, OperantsManager holds three pointers
* to identify each mode respectively.
*
* 2. Decouple phi with the help of the polymorphism mechanism,
* TensorOperantsBase derives three child classes: PhiTensorOperants,
* EagerTensorOperants, and StaticTensorOperants. We set eager and static tensor
* operants at the fluid library and set phi operants at the phi library.
*
*/
class OperantsManager {
public:
static OperantsManager& Instance();
Tensor multiply(const Tensor& x, const Tensor& y);
public:
std::unique_ptr<TensorOperantsBase> eager_operants{nullptr};
std::unique_ptr<TensorOperantsBase> static_operants{nullptr};
std::unique_ptr<TensorOperantsBase> phi_operants{nullptr};
private:
OperantsManager() = default;
DISABLE_COPY_AND_ASSIGN(OperantsManager);
};
} // namespace operants
} // namespace paddle
...@@ -239,6 +239,7 @@ def __bootstrap__(): ...@@ -239,6 +239,7 @@ def __bootstrap__():
core.init_glog(sys.argv[0]) core.init_glog(sys.argv[0])
# don't init_p2p when in unittest to save time. # don't init_p2p when in unittest to save time.
core.init_devices() core.init_devices()
core.eager._init_eager_and_static_tensor_operants()
core.init_default_kernel_signatures() core.init_default_kernel_signatures()
......
...@@ -27,6 +27,8 @@ endif() ...@@ -27,6 +27,8 @@ endif()
py_test(test_custom_raw_op_kernel_op SRCS test_custom_raw_op_kernel_op.py) py_test(test_custom_raw_op_kernel_op SRCS test_custom_raw_op_kernel_op.py)
set_tests_properties(test_custom_raw_op_kernel_op PROPERTIES TIMEOUT 180) set_tests_properties(test_custom_raw_op_kernel_op PROPERTIES TIMEOUT 180)
py_test(test_custom_power_jit SRCS test_custom_power_jit.py)
set_tests_properties(test_custom_power_jit PROPERTIES TIMEOUT 180)
# CPU custom op tests: only compile .cc file # CPU custom op tests: only compile .cc file
py_test(test_dispatch_jit SRCS test_dispatch_jit.py) py_test(test_dispatch_jit SRCS test_dispatch_jit.py)
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include "paddle/extension.h"
std::vector<paddle::Tensor> PowerForward(const paddle::Tensor& x) {
if (x.is_cpu() || x.is_gpu()) {
return {x * x};
} else {
PD_THROW("Not implemented.");
}
}
std::vector<paddle::Tensor> PowerBackward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
if (x.is_cpu() || x.is_gpu()) {
paddle::Tensor middle_result = grad_out * x;
return {paddle::add(middle_result, middle_result)};
} else {
PD_THROW("Not implemented.");
}
}
PD_BUILD_OP(custom_power)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(PowerForward));
PD_BUILD_GRAD_OP(custom_power)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(PowerBackward));
...@@ -118,7 +118,7 @@ std::vector<paddle::Tensor> relu_xpu_backward(const paddle::Tensor& x, ...@@ -118,7 +118,7 @@ std::vector<paddle::Tensor> relu_xpu_backward(const paddle::Tensor& x,
auto zeros = paddle::experimental::full_like(x, 0.0, x.dtype(), x.place()); auto zeros = paddle::experimental::full_like(x, 0.0, x.dtype(), x.place());
auto condition = paddle::experimental::greater_than(x, zeros); auto condition = paddle::experimental::greater_than(x, zeros);
grad_x = paddle::multiply(grad_out, paddle::where(condition, ones, zeros)); grad_x = grad_out * paddle::where(condition, ones, zeros);
return {grad_x}; return {grad_x};
} }
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import numpy as np
from utils import extra_cc_args, paddle_includes
import paddle
import paddle.static as static
from paddle.utils.cpp_extension import get_build_directory, load
from paddle.utils.cpp_extension.extension_utils import run_cmd
# Because Windows don't use docker, the shared lib already exists in the
# cache dir, it will not be compiled again unless the shared lib is removed.
file = '{}\\custom_power_jit\\custom_power_jit.pyd'.format(
get_build_directory()
)
if os.name == 'nt' and os.path.isfile(file):
cmd = 'del {}'.format(file)
run_cmd(cmd, True)
custom_module = load(
name='custom_power_jit',
sources=['custom_power.cc'],
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cxx_cflags=extra_cc_args, # test for cc flags
verbose=True,
)
def custom_power_dynamic(func, device, dtype, np_x, use_func=True):
paddle.set_device(device)
t = paddle.to_tensor(np_x, dtype=dtype)
t.stop_gradient = False
out = func(t) if use_func else paddle.pow(t, 2)
out.stop_gradient = False
out.backward()
if t.grad is None:
return out.numpy(), t.grad
else:
return out.numpy(), t.grad.numpy()
def custom_power_static(func, device, dtype, np_x, use_func=True):
paddle.enable_static()
paddle.set_device(device)
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
x = static.data(name='X', shape=[None, 8], dtype=dtype)
x.stop_gradient = False
out = func(x) if use_func else paddle.pow(x, 2)
static.append_backward(out)
exe = static.Executor()
exe.run(static.default_startup_program())
# in static graph mode, x data has been covered by out
out_v = exe.run(
static.default_main_program(),
feed={'X': np_x},
fetch_list=[out.name],
)
paddle.disable_static()
return out_v
class TestJITLoad(unittest.TestCase):
def setUp(self):
self.custom_op = custom_module.custom_power
self.devices = ['cpu']
self.dtypes = ['float32', 'float64']
if paddle.is_compiled_with_cuda():
self.devices.append('gpu')
self.dtypes.append('float16')
def test_static(self):
for device in self.devices:
for dtype in self.dtypes:
if device == 'cpu' and dtype == 'float16':
continue
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out = custom_power_static(self.custom_op, device, dtype, x)
pd_out = custom_power_static(
self.custom_op, device, dtype, x, False
)
np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
def test_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
if device == 'cpu' and dtype == 'float16':
continue
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out, x_grad = custom_power_dynamic(
self.custom_op, device, dtype, x
)
pd_out, pd_x_grad = custom_power_dynamic(
self.custom_op, device, dtype, x, False
)
np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
np.testing.assert_allclose(
x_grad, pd_x_grad, rtol=1e-5, atol=1e-8
)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册