Unverified commit 3b6ebc9d, authored by HongyuJia, committed by GitHub

[Tensor Operator] Support add, minus, and divide (#50487)

* polish namespace

* change static_tensor_operants

* polish namespace

* support add, subtract, divide

* add unit test

* polish unittest

* fix cmake error

* polish unittest
Parent 383a08e1
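Before the hunks, a minimal usage sketch of what the new overloads enable in custom-operator code (the Polynomial helper is hypothetical, not part of this commit); each operator routes through paddle::OperantsManager::Instance(), as the diffs below show:

#include "paddle/extension.h"

// Hypothetical example, not from this commit: x*x + x - 1 written with the
// newly overloaded tensor operators instead of explicit add/multiply calls.
paddle::Tensor Polynomial(const paddle::Tensor& x) {
  // A tensor of ones serves as the constant term, matching x's shape/dtype.
  paddle::Tensor ones = paddle::full(x.shape(), 1.0, x.dtype(), x.place());
  return x * x + x - ones;  // operator*, +, - dispatch via OperantsManager
}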
@@ -170,7 +170,7 @@ void divide_grad(const Tensor& x,
if (dy) {
// dy = -(x/y^2) * dout
auto tmp0 = pow<T>(y, 2.0);
auto tmp1 = divide<T>(x, tmp0);
auto tmp1 = x / tmp0;
auto tmp2 = scale<T>(tmp1, -1.0, 0.0, true);
auto dy_res = tmp2 * out_grad;
if (x.dims() != y.dims()) {
@@ -191,8 +191,7 @@ void divide_grad(const Tensor& x,
if (dx) {
// dx = (1/y) * dout
auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0, y.dtype());
auto tmp0 = divide<T>(one_tensor, y);
auto dx_res = tmp0 * out_grad;
auto dx_res = one_tensor / y * out_grad;
if (y.dims() != x.dims()) {
// Maybe need reduce here
auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
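For reference, both branches of divide_grad implement the standard quotient-rule gradients. Assuming out = x / y:

\[
\frac{\partial\,\text{out}}{\partial x} = \frac{1}{y},
\qquad
\frac{\partial\,\text{out}}{\partial y} = -\frac{x}{y^{2}},
\]

so dx = (1/y) * out_grad and dy = -(x / y^2) * out_grad, matching the inline comments; the new operator overloads let both be written as plain expressions.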
@@ -215,8 +214,7 @@ template <typename T>
void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
auto div_x = full<T>(phi::vectorize(out.dims()), 0.5);
auto tmp = divide<T>(div_x, out);
auto x_grad_tmp = out_grad * tmp;
auto x_grad_tmp = out_grad * div_x / out;
set_output<T>(x_grad_tmp, x_grad);
}
}
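The sqrt gradient follows the same simplification. With out = sqrt(x):

\[
\frac{\partial\,\text{out}}{\partial x} = \frac{1}{2\sqrt{x}} = \frac{0.5}{\text{out}},
\]

which is why div_x is a tensor filled with 0.5 and the result is out_grad * div_x / out.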
......
@@ -21,9 +21,21 @@ namespace paddle {
namespace prim {
Tensor EagerTensorOperants::add(const Tensor& x, const Tensor& y) {
return ::add_ad_func(x, y);
}
Tensor EagerTensorOperants::subtract(const Tensor& x, const Tensor& y) {
return ::subtract_ad_func(x, y);
}
Tensor EagerTensorOperants::multiply(const Tensor& x, const Tensor& y) {
return ::multiply_ad_func(x, y);
}
Tensor EagerTensorOperants::divide(const Tensor& x, const Tensor& y) {
return ::divide_ad_func(x, y);
}
} // namespace prim
} // namespace paddle
@@ -29,8 +29,14 @@ class EagerTensorOperants : public TensorOperantsBase {
public:
EagerTensorOperants() = default;
Tensor add(const Tensor& x, const Tensor& y) override;
Tensor subtract(const Tensor& x, const Tensor& y) override;
Tensor multiply(const Tensor& x, const Tensor& y) override;
Tensor divide(const Tensor& x, const Tensor& y) override;
private:
DISABLE_COPY_AND_ASSIGN(EagerTensorOperants);
};
......
@@ -23,9 +23,22 @@ namespace paddle {
namespace prim {
using DescTensor = paddle::prim::DescTensor;
Tensor StaticTensorOperants::add(const Tensor& x, const Tensor& y) {
return paddle::prim::add<DescTensor>(x, y);
}
Tensor StaticTensorOperants::subtract(const Tensor& x, const Tensor& y) {
return paddle::prim::add<DescTensor>(
x, paddle::prim::scale<DescTensor>(y, -1, 0, 0));
}
Tensor StaticTensorOperants::multiply(const Tensor& x, const Tensor& y) {
return paddle::prim::multiply<DescTensor>(x, y);
}
Tensor StaticTensorOperants::divide(const Tensor& x, const Tensor& y) {
return paddle::prim::divide<DescTensor>(x, y);
}
} // namespace prim
} // namespace paddle
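Note that, unlike the eager path, StaticTensorOperants does not call a subtract primitive here; subtraction is composed from add and scale (an observation from the hunk above, not a documented constraint):

\[
x - y = x + \text{scale}(y,\, -1),
\]

i.e. y is scaled by -1 (with zero bias) and added to x.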
@@ -29,8 +29,14 @@ class StaticTensorOperants : public TensorOperantsBase {
public:
StaticTensorOperants() = default;
Tensor add(const Tensor& x, const Tensor& y) override;
Tensor subtract(const Tensor& x, const Tensor& y) override;
Tensor multiply(const Tensor& x, const Tensor& y) override;
Tensor divide(const Tensor& x, const Tensor& y) override;
private:
DISABLE_COPY_AND_ASSIGN(StaticTensorOperants);
};
......
@@ -26,7 +26,13 @@ class TensorOperantsBase {
public:
virtual ~TensorOperantsBase() = default;
virtual Tensor add(const Tensor& x, const Tensor& y) = 0;
virtual Tensor subtract(const Tensor& x, const Tensor& y) = 0;
virtual Tensor multiply(const Tensor& x, const Tensor& y) = 0;
virtual Tensor divide(const Tensor& x, const Tensor& y) = 0;
};
} // namespace operants
......
@@ -60,8 +60,14 @@ class OperantsManager {
public:
static OperantsManager& Instance();
Tensor add(const Tensor& x, const Tensor& y);
Tensor subtract(const Tensor& x, const Tensor& y);
Tensor multiply(const Tensor& x, const Tensor& y);
Tensor divide(const Tensor& x, const Tensor& y);
public:
std::unique_ptr<TensorOperantsBase> eager_operants{nullptr};
std::unique_ptr<TensorOperantsBase> static_operants{nullptr};
......
@@ -633,7 +633,13 @@ class PADDLE_API Tensor final {
std::string name_{""};
};
PADDLE_API Tensor operator+(const Tensor& x, const Tensor& y);
PADDLE_API Tensor operator-(const Tensor& x, const Tensor& y);
PADDLE_API Tensor operator*(const Tensor& x, const Tensor& y);
PADDLE_API Tensor operator/(const Tensor& x, const Tensor& y);
} // namespace experimental
} // namespace paddle
@@ -26,8 +26,14 @@ class PhiTensorOperants : public TensorOperantsBase {
public:
PhiTensorOperants() = default;
Tensor add(const Tensor& x, const Tensor& y) override;
Tensor subtract(const Tensor& x, const Tensor& y) override;
Tensor multiply(const Tensor& x, const Tensor& y) override;
Tensor divide(const Tensor& x, const Tensor& y) override;
private:
DISABLE_COPY_AND_ASSIGN(PhiTensorOperants);
};
......
@@ -28,6 +28,72 @@ OperantsManager& OperantsManager::Instance() {
return g_op_manager;
}
Tensor OperantsManager::add(const Tensor& x, const Tensor& y) {
if (FLAGS_tensor_operants_mode == "eager") {
PADDLE_ENFORCE_NE(
this->eager_operants.get(),
nullptr,
phi::errors::Unavailable("The eager_operants pointer of "
"OperantsManager is not initialized"));
VLOG(4) << "OperantsManager reaches eager mode";
return this->eager_operants->add(x, y);
} else if (FLAGS_tensor_operants_mode == "static") {
PADDLE_ENFORCE_NE(
this->static_operants.get(),
nullptr,
phi::errors::Unavailable("The static_operants pointer of "
"OperantsManager is not initialized"));
VLOG(4) << "OperantsManager reaches static mode";
return this->static_operants->add(x, y);
} else if (FLAGS_tensor_operants_mode == "phi") {
PADDLE_ENFORCE_NE(
this->phi_operants.get(),
nullptr,
phi::errors::Unavailable(
"The phi_operants pointer of OperantsManager is not initialized"));
VLOG(4) << "OperantsManager reaches phi mode";
return this->phi_operants->add(x, y);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"FLAGS_tensor_operants_mode is not nitialized, please set "
"FLAGS_tensor_operants_mode first, which currently supports eager, "
"phi, and static mode"));
}
}
Tensor OperantsManager::subtract(const Tensor& x, const Tensor& y) {
if (FLAGS_tensor_operants_mode == "eager") {
PADDLE_ENFORCE_NE(
this->eager_operants.get(),
nullptr,
phi::errors::Unavailable("The eager_operants pointer of "
"OperantsManager is not initialized"));
VLOG(4) << "OperantsManager reaches eager mode";
return this->eager_operants->subtract(x, y);
} else if (FLAGS_tensor_operants_mode == "static") {
PADDLE_ENFORCE_NE(
this->static_operants.get(),
nullptr,
phi::errors::Unavailable("The static_operants pointer of "
"OperantsManager is not initialized"));
VLOG(4) << "OperantsManager reaches static mode";
return this->static_operants->subtract(x, y);
} else if (FLAGS_tensor_operants_mode == "phi") {
PADDLE_ENFORCE_NE(
this->phi_operants.get(),
nullptr,
phi::errors::Unavailable(
"The phi_operants pointer of OperantsManager is not initialized"));
VLOG(4) << "OperantsManager reaches phi mode";
return this->phi_operants->subtract(x, y);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"FLAGS_tensor_operants_mode is not nitialized, please set "
"FLAGS_tensor_operants_mode first, which currently supports eager, "
"phi, and static mode"));
}
}
Tensor OperantsManager::multiply(const Tensor& x, const Tensor& y) {
if (FLAGS_tensor_operants_mode == "eager") {
PADDLE_ENFORCE_NE(
@@ -61,4 +127,37 @@ Tensor OperantsManager::multiply(const Tensor& x, const Tensor& y) {
}
}
Tensor OperantsManager::divide(const Tensor& x, const Tensor& y) {
if (FLAGS_tensor_operants_mode == "eager") {
PADDLE_ENFORCE_NE(
this->eager_operants.get(),
nullptr,
phi::errors::Unavailable("The eager_operants pointer of "
"OperantsManager is not initialized"));
VLOG(4) << "OperantsManager reaches eager mode";
return this->eager_operants->divide(x, y);
} else if (FLAGS_tensor_operants_mode == "static") {
PADDLE_ENFORCE_NE(
this->static_operants.get(),
nullptr,
phi::errors::Unavailable("The static_operants pointer of "
"OperantsManager is not initialized"));
VLOG(4) << "OperantsManager reaches static mode";
return this->static_operants->divide(x, y);
} else if (FLAGS_tensor_operants_mode == "phi") {
PADDLE_ENFORCE_NE(
this->phi_operants.get(),
nullptr,
phi::errors::Unavailable(
"The phi_operants pointer of OperantsManager is not initialized"));
VLOG(4) << "OperantsManager reaches phi mode";
return this->phi_operants->divide(x, y);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"FLAGS_tensor_operants_mode is not nitialized, please set "
"FLAGS_tensor_operants_mode first, which currently supports eager, "
"phi, and static mode"));
}
}
} // namespace paddle
@@ -434,9 +434,21 @@ void Tensor::reset_inplace_version(bool set_to_zero) {
}
}
PADDLE_API Tensor operator+(const Tensor &x, const Tensor &y) {
return paddle::OperantsManager::Instance().add(x, y);
}
PADDLE_API Tensor operator-(const Tensor &x, const Tensor &y) {
return paddle::OperantsManager::Instance().subtract(x, y);
}
PADDLE_API Tensor operator*(const Tensor &x, const Tensor &y) {
return paddle::OperantsManager::Instance().multiply(x, y);
}
PADDLE_API Tensor operator/(const Tensor &x, const Tensor &y) {
return paddle::OperantsManager::Instance().divide(x, y);
}
} // namespace experimental
} // namespace paddle
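Taken together, these hunks complete the dispatch chain for the new overloads: operator+, operator-, operator*, and operator/ on Tensor forward to OperantsManager::Instance(), which routes on FLAGS_tensor_operants_mode to the eager (ad_func), static (prim DescTensor), or phi (paddle::experimental) operants registered earlier.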
@@ -20,9 +20,21 @@ namespace paddle {
namespace operants {
Tensor PhiTensorOperants::add(const Tensor& x, const Tensor& y) {
return paddle::experimental::add(x, y);
}
Tensor PhiTensorOperants::subtract(const Tensor& x, const Tensor& y) {
return paddle::experimental::subtract(x, y);
}
Tensor PhiTensorOperants::multiply(const Tensor& x, const Tensor& y) {
return paddle::experimental::multiply(x, y);
}
Tensor PhiTensorOperants::divide(const Tensor& x, const Tensor& y) {
return paddle::experimental::divide(x, y);
}
} // namespace operants
} // namespace paddle
@@ -27,8 +27,8 @@ endif()
py_test(test_custom_raw_op_kernel_op SRCS test_custom_raw_op_kernel_op.py)
set_tests_properties(test_custom_raw_op_kernel_op PROPERTIES TIMEOUT 180)
py_test(test_custom_power_jit SRCS test_custom_power_jit.py)
set_tests_properties(test_custom_power_jit PROPERTIES TIMEOUT 180)
py_test(test_custom_tensor_operator SRCS test_custom_tensor_operator.py)
set_tests_properties(test_custom_tensor_operator PROPERTIES TIMEOUT 180)
# CPU custom op tests: only compile .cc file
py_test(test_dispatch_jit SRCS test_dispatch_jit.py)
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include "paddle/extension.h"
std::vector<paddle::Tensor> PowerForward(const paddle::Tensor& x) {
if (x.is_cpu() || x.is_gpu()) {
return {x * x};
} else {
PD_THROW("Not implemented.");
}
}
std::vector<paddle::Tensor> PowerBackward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
if (x.is_cpu() || x.is_gpu()) {
paddle::Tensor middle_result = grad_out * x;
return {paddle::add(middle_result, middle_result)};
} else {
PD_THROW("Not implemented.");
}
}
PD_BUILD_OP(custom_power)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(PowerForward));
PD_BUILD_GRAD_OP(custom_power)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(PowerBackward));
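Note that this custom_power op already exercises the new overloads end to end: the forward computes x * x and the backward computes grad_out * x through operator*.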
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include "paddle/extension.h"
// y = x + 1
std::vector<paddle::Tensor> AddForward(const paddle::Tensor& x) {
if (x.is_cpu() || x.is_gpu()) {
paddle::Tensor ones = paddle::full(x.shape(), 1.0, x.dtype(), x.place());
return {x + ones};
} else {
PD_THROW("Not implemented.");
}
}
// dy / dx = 1 * grad_out
std::vector<paddle::Tensor> AddBackward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
if (x.is_cpu() || x.is_gpu()) {
paddle::Tensor ones = paddle::full(x.shape(), 1.0, x.dtype(), x.place());
return {grad_out * ones};
} else {
PD_THROW("Not implemented.");
}
}
PD_BUILD_OP(custom_add)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(AddForward));
PD_BUILD_GRAD_OP(custom_add)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(AddBackward));
// y = x - 1
std::vector<paddle::Tensor> SubtractForward(const paddle::Tensor& x) {
if (x.is_cpu() || x.is_gpu()) {
paddle::Tensor ones = paddle::full(x.shape(), 1, x.dtype(), x.place());
return {x - ones};
} else {
PD_THROW("Not implemented.");
}
}
// dy / dx = 1 * grad_out
std::vector<paddle::Tensor> SubtractBackward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
if (x.is_cpu() || x.is_gpu()) {
return {grad_out};
} else {
PD_THROW("Not implemented.");
}
}
PD_BUILD_OP(custom_subtract)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(SubtractForward));
PD_BUILD_GRAD_OP(custom_subtract)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(SubtractBackward));
// y = x * 5
std::vector<paddle::Tensor> MultiplyForward(const paddle::Tensor& x) {
if (x.is_cpu() || x.is_gpu()) {
paddle::Tensor ones = paddle::full(x.shape(), 1.0, x.dtype(), x.place());
paddle::Tensor fives = paddle::experimental::fill(ones, 5);
return {x * fives};
} else {
PD_THROW("Not implemented.");
}
}
// dy / dx = 5 * grad_out
std::vector<paddle::Tensor> MultiplyBackward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
if (x.is_cpu() || x.is_gpu()) {
paddle::Tensor ones = paddle::full(x.shape(), 1.0, x.dtype(), x.place());
paddle::Tensor fives = paddle::experimental::fill(ones, 5);
return {fives * grad_out};
} else {
PD_THROW("Not implemented.");
}
}
PD_BUILD_OP(custom_multiply)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(MultiplyForward));
PD_BUILD_GRAD_OP(custom_multiply)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(MultiplyBackward));
// y = 1 / x
std::vector<paddle::Tensor> DivideForward(const paddle::Tensor& x) {
if (x.is_cpu() || x.is_gpu()) {
paddle::Tensor ones = paddle::full(x.shape(), 1.0, x.dtype(), x.place());
return {ones / x};
} else {
PD_THROW("Not implemented.");
}
}
// dy / dx = - (1 / x / x) * grad_out
std::vector<paddle::Tensor> DivideBackward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
if (x.is_cpu() || x.is_gpu()) {
paddle::Tensor zeros = paddle::full(x.shape(), 0.0, x.dtype(), x.place());
return {zeros - grad_out / (x * x)};
} else {
PD_THROW("Not implemented.");
}
}
PD_BUILD_OP(custom_divide)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DivideForward));
PD_BUILD_GRAD_OP(custom_divide)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(DivideBackward));
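One detail in DivideBackward: only binary operators are added by this change (see the operator declarations above), so the negation is spelled zeros - grad_out / (x * x) with an explicit zeros tensor rather than a unary minus.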
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import numpy as np
from utils import extra_cc_args, paddle_includes
import paddle
import paddle.static as static
from paddle.utils.cpp_extension import get_build_directory, load
from paddle.utils.cpp_extension.extension_utils import run_cmd
# Because Windows doesn't use docker, the shared lib already exists in the
# cache dir; it will not be compiled again unless it is removed.
file = '{}\\custom_power_jit\\custom_power_jit.pyd'.format(
get_build_directory()
)
if os.name == 'nt' and os.path.isfile(file):
cmd = 'del {}'.format(file)
run_cmd(cmd, True)
custom_module = load(
name='custom_power_jit',
sources=['custom_power.cc'],
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cxx_cflags=extra_cc_args, # test for cc flags
verbose=True,
)
def custom_power_dynamic(func, device, dtype, np_x, use_func=True):
paddle.set_device(device)
t = paddle.to_tensor(np_x, dtype=dtype)
t.stop_gradient = False
out = func(t) if use_func else paddle.pow(t, 2)
out.stop_gradient = False
out.backward()
if t.grad is None:
return out.numpy(), t.grad
else:
return out.numpy(), t.grad.numpy()
def custom_power_static(func, device, dtype, np_x, use_func=True):
paddle.enable_static()
paddle.set_device(device)
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
x = static.data(name='X', shape=[None, 8], dtype=dtype)
x.stop_gradient = False
out = func(x) if use_func else paddle.pow(x, 2)
static.append_backward(out)
exe = static.Executor()
exe.run(static.default_startup_program())
# in static graph mode, x's data has been overwritten by out
out_v = exe.run(
static.default_main_program(),
feed={'X': np_x},
fetch_list=[out.name],
)
paddle.disable_static()
return out_v
class TestJITLoad(unittest.TestCase):
def setUp(self):
self.custom_op = custom_module.custom_power
self.devices = ['cpu']
self.dtypes = ['float32', 'float64']
if paddle.is_compiled_with_cuda():
self.devices.append('gpu')
self.dtypes.append('float16')
def test_static(self):
for device in self.devices:
for dtype in self.dtypes:
if device == 'cpu' and dtype == 'float16':
continue
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out = custom_power_static(self.custom_op, device, dtype, x)
pd_out = custom_power_static(
self.custom_op, device, dtype, x, False
)
np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
def test_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
if device == 'cpu' and dtype == 'float16':
continue
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out, x_grad = custom_power_dynamic(
self.custom_op, device, dtype, x
)
pd_out, pd_x_grad = custom_power_dynamic(
self.custom_op, device, dtype, x, False
)
np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
np.testing.assert_allclose(
x_grad, pd_x_grad, rtol=1e-5, atol=1e-8
)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import numpy as np
from utils import extra_cc_args, paddle_includes
import paddle
import paddle.static as static
from paddle.utils.cpp_extension import get_build_directory, load
from paddle.utils.cpp_extension.extension_utils import run_cmd
# Because Windows doesn't use docker, the shared lib already exists in the
# cache dir; it will not be compiled again unless it is removed.
file = '{}\\custom_tensor_operator\\custom_tensor_operator.pyd'.format(
get_build_directory()
)
if os.name == 'nt' and os.path.isfile(file):
cmd = 'del {}'.format(file)
run_cmd(cmd, True)
def test_custom_add_dynamic(func, device, dtype, np_x, use_func=True):
paddle.set_device(device)
x = paddle.to_tensor(np_x, dtype=dtype)
x.stop_gradient = False
if use_func:
out = func(x)
else:
out = x + 1
out.stop_gradient = False
out.backward()
if x.grad is None:
return out.numpy(), x.grad
else:
return out.numpy(), x.grad.numpy()
def test_custom_add_static(func, device, dtype, np_x, use_func=True):
paddle.enable_static()
paddle.set_device(device)
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
x = static.data(name='X', shape=[None, 8], dtype=dtype)
x.stop_gradient = False
if use_func:
out = func(x)
else:
out = x + 1
static.append_backward(out)
exe = static.Executor()
exe.run(static.default_startup_program())
# in static graph mode, x's data has been overwritten by out
out_v = exe.run(
static.default_main_program(),
feed={'X': np_x},
fetch_list=[out.name],
)
paddle.disable_static()
return out_v
def test_custom_subtract_dynamic(func, device, dtype, np_x, use_func=True):
paddle.set_device(device)
x = paddle.to_tensor(np_x, dtype=dtype)
x.stop_gradient = False
if use_func:
out = func(x)
else:
out = x - 1
out.stop_gradient = False
out.backward()
if x.grad is None:
return out.numpy(), x.grad
else:
return out.numpy(), x.grad.numpy()
def test_custom_subtract_static(func, device, dtype, np_x, use_func=True):
paddle.enable_static()
paddle.set_device(device)
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
x = static.data(name='X', shape=[None, 8], dtype=dtype)
x.stop_gradient = False
if use_func:
out = func(x)
else:
out = x - 1
static.append_backward(out)
exe = static.Executor()
exe.run(static.default_startup_program())
# in static graph mode, x's data has been overwritten by out
out_v = exe.run(
static.default_main_program(),
feed={'X': np_x},
fetch_list=[out.name],
)
paddle.disable_static()
return out_v
def test_custom_multiply_dynamic(func, device, dtype, np_x, use_func=True):
paddle.set_device(device)
x = paddle.to_tensor(np_x, dtype=dtype)
x.stop_gradient = False
if use_func:
out = func(x)
else:
out = x * 5
out.stop_gradient = False
out.backward()
if x.grad is None:
return out.numpy(), x.grad
else:
return out.numpy(), x.grad.numpy()
def test_custom_multiply_static(func, device, dtype, np_x, use_func=True):
paddle.enable_static()
paddle.set_device(device)
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
x = static.data(name='X', shape=[None, 8], dtype=dtype)
x.stop_gradient = False
if use_func:
out = func(x)
else:
out = x * 5
static.append_backward(out)
exe = static.Executor()
exe.run(static.default_startup_program())
# in static graph mode, x's data has been overwritten by out
out_v = exe.run(
static.default_main_program(),
feed={'X': np_x},
fetch_list=[out.name],
)
paddle.disable_static()
return out_v
def test_custom_divide_dynamic(func, device, dtype, np_x, use_func=True):
paddle.set_device(device)
x = paddle.to_tensor(np_x, dtype=dtype)
x.stop_gradient = False
if use_func:
out = func(x)
else:
out = paddle.reciprocal(x)
out.stop_gradient = False
out.backward()
if x.grad is None:
return out.numpy(), x.grad
else:
return out.numpy(), x.grad.numpy()
def test_custom_divide_static(func, device, dtype, np_x, use_func=True):
paddle.enable_static()
paddle.set_device(device)
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
x = static.data(name='X', shape=[4, 8], dtype=dtype)
x.stop_gradient = False
if use_func:
out = func(x)
else:
out = paddle.reciprocal(x)
static.append_backward(out)
exe = static.Executor()
exe.run(static.default_startup_program())
# in static graph mode, x's data has been overwritten by out
out_v = exe.run(
static.default_main_program(),
feed={'X': np_x},
fetch_list=[out.name],
)
paddle.disable_static()
return out_v
class TestJITLoad(unittest.TestCase):
def setUp(self):
self.custom_module = load(
name='custom_tensor_operator',
sources=['custom_tensor_operator.cc'],
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cxx_cflags=extra_cc_args, # test for cc flags
verbose=True,
)
self.devices = ['cpu']
self.dtypes = ['float32', 'float64']
if paddle.is_compiled_with_cuda():
self.devices.append('gpu')
self.dtypes.append('float16')
def test_all(self):
self._test_static()
self._test_dynamic()
def _test_static(self):
for device in self.devices:
for dtype in self.dtypes:
if device == 'cpu' and dtype == 'float16':
continue
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out = test_custom_add_static(
self.custom_module.custom_add, device, dtype, x
)
pd_out = test_custom_add_static(
self.custom_module.custom_add, device, dtype, x, False
)
np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
out = test_custom_subtract_static(
self.custom_module.custom_subtract, device, dtype, x
)
pd_out = test_custom_subtract_static(
self.custom_module.custom_subtract, device, dtype, x, False
)
np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
out = test_custom_multiply_static(
self.custom_module.custom_multiply, device, dtype, x
)
pd_out = test_custom_multiply_static(
self.custom_module.custom_multiply, device, dtype, x, False
)
np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
out = test_custom_divide_static(
self.custom_module.custom_divide, device, dtype, x
)
pd_out = test_custom_divide_static(
self.custom_module.custom_divide, device, dtype, x, False
)
np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
def _test_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
if device == 'cpu' and dtype == 'float16':
continue
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out, x_grad = test_custom_add_dynamic(
self.custom_module.custom_add, device, dtype, x
)
pd_out, pd_x_grad = test_custom_add_dynamic(
self.custom_module.custom_add, device, dtype, x, False
)
np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
np.testing.assert_allclose(
x_grad, pd_x_grad, rtol=1e-5, atol=1e-8
)
out, x_grad = test_custom_subtract_dynamic(
self.custom_module.custom_subtract, device, dtype, x
)
pd_out, pd_x_grad = test_custom_subtract_dynamic(
self.custom_module.custom_subtract, device, dtype, x, False
)
np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
np.testing.assert_allclose(
x_grad, pd_x_grad, rtol=1e-5, atol=1e-8
)
out, x_grad = test_custom_multiply_dynamic(
self.custom_module.custom_multiply, device, dtype, x
)
pd_out, pd_x_grad = test_custom_multiply_dynamic(
self.custom_module.custom_multiply, device, dtype, x, False
)
np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
np.testing.assert_allclose(
x_grad, pd_x_grad, rtol=1e-5, atol=1e-8
)
out, x_grad = test_custom_divide_dynamic(
self.custom_module.custom_divide, device, dtype, x
)
pd_out, pd_x_grad = test_custom_divide_dynamic(
self.custom_module.custom_divide, device, dtype, x, False
)
np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
if __name__ == '__main__':
unittest.main()