Unverified Commit 5f5a2082 authored by HongyuJia, committed by GitHub

[Tensor Operants & Prim] Tensor arithmetic operants support right scalar type (#50563)

* polish namespace

* change static_tensor_operants

* polish namespace

* support add, subtract, divide

* add unit test

* polish unittest

* fix cmake error

* solve conflicts, merge auto code-gen

* add scalar operator in tensor.h

* tensorbase

* static prim full support more datatype

* fix prim unittest

* polish codes

* fix cmake error
Parent 8476c552
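
In short, the right-hand operand of Tensor arithmetic can now be a Scalar. A minimal usage sketch (hypothetical helper function, not part of this diff; assumes the public phi Tensor header):

// Hypothetical illustration of the Tensor (op) Scalar overloads added here.
// Each operator forwards to OperantsManager, which lowers the Scalar to a
// tensor (e.g. via full_like(x, y)) and reuses the tensor-tensor kernels.
#include "paddle/phi/api/include/tensor.h"

paddle::Tensor scale_and_shift(const paddle::Tensor& x) {
  paddle::Tensor a = x * 2.0;  // multiply(x, Scalar(2.0))
  paddle::Tensor b = a + 1.0;  // add(a, Scalar(1.0))
  return b / 4.0;              // divide(b, Scalar(4.0))
}
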
- add
- subtract
- multiply
- divide
- unsqueeze
- pow
- exp
- scale
- multiply
- matmul
- expand
- divide
- sum
- add
- abs
- assign
- concat
@@ -24,4 +25,3 @@
- scatter_nd_add
- tile
- transpose
- subtract
@@ -48,6 +48,14 @@ class EagerTensorOperants : public TensorOperantsBase {
public:
EagerTensorOperants() = default;
Tensor add(const Tensor& x, const Scalar& y);
Tensor subtract(const Tensor& x, const Scalar& y);
Tensor multiply(const Tensor& x, const Scalar& y);
Tensor divide(const Tensor& x, const Scalar& y);
"""
@@ -73,6 +81,22 @@ namespace paddle {
namespace prim {
Tensor EagerTensorOperants::add(const Tensor& x, const Scalar& y) {
return ::add_ad_func(x, ::full_like_ad_func(x, y));
}
Tensor EagerTensorOperants::subtract(const Tensor& x, const Scalar& y) {
return ::subtract_ad_func(x, ::full_like_ad_func(x, y));
}
Tensor EagerTensorOperants::multiply(const Tensor& x, const Scalar& y) {
return ::multiply_ad_func(x, ::full_like_ad_func(x, y));
}
Tensor EagerTensorOperants::divide(const Tensor& x, const Scalar& y) {
return ::divide_ad_func(x, ::full_like_ad_func(x, y));
}
"""
@@ -112,6 +136,14 @@ class StaticTensorOperants : public TensorOperantsBase {
public:
StaticTensorOperants() = default;
Tensor add(const Tensor& x, const Scalar& y);
Tensor subtract(const Tensor& x, const Scalar& y);
Tensor multiply(const Tensor& x, const Scalar& y);
Tensor divide(const Tensor& x, const Scalar& y);
"""
@@ -128,6 +160,7 @@ static_source_include = """// Generated by paddle/fluid/prim/api/auto_code_gener
#include "paddle/fluid/prim/utils/static/static_tensor_operants.h"
#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
#include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
"""
@@ -139,6 +172,22 @@ namespace paddle {
namespace prim {
using DescTensor = paddle::prim::DescTensor;
Tensor StaticTensorOperants::add(const Tensor& x, const Scalar& y) {
return paddle::prim::add<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
}
Tensor StaticTensorOperants::subtract(const Tensor& x, const Scalar& y) {
return paddle::prim::subtract<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
}
Tensor StaticTensorOperants::multiply(const Tensor& x, const Scalar& y) {
return paddle::prim::multiply<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
}
Tensor StaticTensorOperants::divide(const Tensor& x, const Scalar& y) {
return paddle::prim::divide<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
}
"""
......
@@ -61,9 +61,7 @@ void gather_grad(const Tensor& x,
template <typename T>
void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
  if (!grad_x) return;
-  auto tmp = out.pow(2.0);
-  tmp = scale<T>(tmp, -1.0, 1.0, true);
-  auto grad_x_tmp = grad_out * tmp;
+  auto grad_x_tmp = grad_out * (out.pow(2.0) * -1.0 + 1.0);
  set_output<T>(grad_x_tmp, grad_x);
}
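
For reference, the rule is unchanged: tanh'(x) = 1 - tanh(x)^2, so grad_x = grad_out * (1 - out^2), which the one-liner writes as grad_out * (out.pow(2.0) * -1.0 + 1.0) via the new Tensor-Scalar operators.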
@@ -203,10 +201,7 @@ void divide_grad(const Tensor& x,
                 Tensor* dy) {
  if (dy) {
    // dy = -(x/y^2) * dout
-    auto tmp0 = y.pow(2.0);
-    auto tmp1 = x / tmp0;
-    auto tmp2 = scale<T>(tmp1, -1.0, 0.0, true);
-    auto dy_res = tmp2 * out_grad;
+    auto dy_res = x / y.pow(2.0) * -1.0 * out_grad;
if (x.dims() != y.dims()) {
// Maybe need reduce here
phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
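
For z = x / y, dz/dy = -x / y^2, hence dy = -(x / y^2) * dout; the former scale<T>(tmp1, -1.0, 0.0, true) step is folded into the inline * -1.0.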
@@ -247,8 +242,7 @@ void divide_grad(const Tensor& x,
template <typename T>
void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
  if (x_grad) {
-    auto div_x = full<T>(phi::vectorize(out.dims()), 0.5);
-    auto x_grad_tmp = out_grad * div_x / out;
+    auto x_grad_tmp = out_grad / 2.0 / out;
set_output<T>(x_grad_tmp, x_grad);
}
}
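
Since out = sqrt(x), dout/dx = 1 / (2 * sqrt(x)) = 1 / (2 * out), so x_grad = out_grad / 2.0 / out; no explicit full<T>(..., 0.5) tensor is needed once Scalar division exists.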
......
@@ -67,23 +67,51 @@ Tensor full<DescTensor>(const IntArray& shape,
  framework::OpDesc* op = block->AppendOp();
  op->SetType("fill_constant");
  op->SetAttr("shape", shape.GetData());
-  PADDLE_ENFORCE_EQ(
-      ((dtype == DataType::FLOAT32) || (dtype == DataType::FLOAT64) ||
-       (dtype == DataType::FLOAT16)),
-      true,
-      phi::errors::InvalidArgument(
-          "We only support float32/float16 for full, but we got data type: %s",
-          phi::DataTypeToString(dtype)));
-  if (dtype == phi::DataType::FLOAT32) {
-    op->SetAttr("value", value.to<float>());
-  } else if (dtype == phi::DataType::FLOAT64) {
-    op->SetAttr("str_value", std::to_string(value.to<double>()));
-  } else if (dtype == phi::DataType::FLOAT16) {
-    op->SetAttr("str_value", std::to_string(value.to<float>()));
-  } else {
-    PADDLE_THROW(phi::errors::Unimplemented(
-        "We only support float64/float32/float16 for full"));
-  }
+  switch (dtype) {
+    case phi::DataType::FLOAT16:
+      op->SetAttr("str_value", std::to_string(value.to<float>()));
+      break;
+    case phi::DataType::FLOAT32:
+      op->SetAttr("value", value.to<float>());
+      break;
+    case phi::DataType::FLOAT64:
+      op->SetAttr("str_value", std::to_string(value.to<double>()));
+      break;
+    case phi::DataType::BOOL:
+      op->SetAttr("str_value", std::to_string(value.to<bool>()));
+      break;
+    case phi::DataType::INT8:
+      op->SetAttr("str_value", std::to_string(value.to<int8_t>()));
+      break;
+    case phi::DataType::UINT8:
+      op->SetAttr("str_value", std::to_string(value.to<uint8_t>()));
+      break;
+    case phi::DataType::INT16:
+      op->SetAttr("str_value", std::to_string(value.to<int16_t>()));
+      break;
+    case phi::DataType::UINT16:
+      op->SetAttr("str_value", std::to_string(value.to<uint16_t>()));
+      break;
+    case phi::DataType::INT32:
+      op->SetAttr("str_value", std::to_string(value.to<int32_t>()));
+      break;
+    case phi::DataType::UINT32:
+      op->SetAttr("str_value", std::to_string(value.to<uint32_t>()));
+      break;
+    case phi::DataType::INT64:
+      op->SetAttr("str_value", std::to_string(value.to<int64_t>()));
+      break;
+    case phi::DataType::UINT64:
+      op->SetAttr("str_value", std::to_string(value.to<uint64_t>()));
+      break;
+    default:
+      PADDLE_THROW(phi::errors::Unimplemented(
+          "We support "
+          "bool/float16/float32/float64/int8/int16/int32/int64/uint8/uint16/"
+          "uint32/uint64 for full, but we got data type: %s",
+          phi::DataTypeToString(dtype)));
+  }
  op->SetAttr("dtype", paddle::framework::TransToProtoVarType(dtype));
  op->SetOutput(
      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
......
@@ -29,6 +29,8 @@ cc_test_old(
prim_utils
operator
elementwise_mul_op
elementwise_add_op
fill_constant_op
activation_op
phi_api
phi_dygraph_api
......
@@ -35,6 +35,7 @@ PD_DECLARE_KERNEL(tanh, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(pow, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(concat, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -43,6 +44,7 @@ PD_DECLARE_KERNEL(tanh, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(pow, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(scale, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(concat, GPU, ALL_LAYOUT);
#endif
@@ -192,7 +194,7 @@ TEST(StaticPrim, TanhBackwardComposite) {
                              target_block,
                              grad_sub_block));
  ASSERT_EQ(target_block->AllOps().size(), static_cast<std::size_t>(1));
-  ASSERT_EQ(grad_ops.size(), static_cast<std::size_t>(3));
+  ASSERT_EQ(grad_ops.size(), static_cast<std::size_t>(6));
ASSERT_EQ(target_block->AllOps()[0]->Type(), "tanh");
ASSERT_EQ(target_block->AllOps()[0]->Inputs().at("X").size(),
static_cast<std::size_t>(1));
@@ -210,14 +212,9 @@ TEST(StaticPrim, TanhBackwardComposite) {
  ASSERT_EQ(grad_ops[0]->Outputs().at("Out").size(),
            static_cast<std::size_t>(1));
-  ASSERT_EQ(grad_ops[1]->Type(), "scale");
-  ASSERT_EQ(grad_ops[1]->Inputs().at("X").size(), static_cast<std::size_t>(1));
-  ASSERT_EQ(grad_ops[1]->Inputs().at("X")[0],
-            grad_ops[0]->Outputs().at("Out")[0]);
-  ASSERT_EQ(PADDLE_GET_CONST(float, grad_ops[1]->GetAttr("scale")),
-            static_cast<float>(-1.0));
-  ASSERT_EQ(PADDLE_GET_CONST(float, grad_ops[1]->GetAttr("bias")),
-            static_cast<float>(1.0));
+  ASSERT_EQ(grad_ops[1]->Type(), "fill_constant");
+  ASSERT_EQ(PADDLE_GET_CONST(int, grad_ops[1]->GetAttr("dtype")),
+            static_cast<int>(5));  // ProtoDataType::FP32
ASSERT_EQ(grad_ops[1]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));
@@ -226,9 +223,31 @@ TEST(StaticPrim, TanhBackwardComposite) {
  ASSERT_EQ(grad_ops[2]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
  ASSERT_EQ(grad_ops[2]->Inputs().at("Y")[0],
            grad_ops[1]->Outputs().at("Out")[0]);
-  ASSERT_EQ(grad_ops[2]->Inputs().at("X")[0], "b@GRAD");
  ASSERT_EQ(grad_ops[2]->Outputs().at("Out").size(),
            static_cast<std::size_t>(1));
+  ASSERT_EQ(grad_ops[3]->Type(), "fill_constant");
+  ASSERT_EQ(PADDLE_GET_CONST(int, grad_ops[3]->GetAttr("dtype")),
+            static_cast<int>(5));  // ProtoDataType::FP32
+  ASSERT_EQ(grad_ops[3]->Outputs().at("Out").size(),
+            static_cast<std::size_t>(1));
+  ASSERT_EQ(grad_ops[4]->Type(), "elementwise_add");
+  ASSERT_EQ(grad_ops[4]->Inputs().at("X").size(), static_cast<std::size_t>(1));
+  ASSERT_EQ(grad_ops[4]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
+  ASSERT_EQ(grad_ops[4]->Inputs().at("Y")[0],
+            grad_ops[3]->Outputs().at("Out")[0]);
+  ASSERT_EQ(grad_ops[4]->Outputs().at("Out").size(),
+            static_cast<std::size_t>(1));
+  ASSERT_EQ(grad_ops[5]->Type(), "elementwise_mul");
+  ASSERT_EQ(grad_ops[5]->Inputs().at("X").size(), static_cast<std::size_t>(1));
+  ASSERT_EQ(grad_ops[5]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
+  ASSERT_EQ(grad_ops[5]->Inputs().at("Y")[0],
+            grad_ops[4]->Outputs().at("Out")[0]);
+  ASSERT_EQ(grad_ops[5]->Inputs().at("X")[0], "b@GRAD");
+  ASSERT_EQ(grad_ops[5]->Outputs().at("Out").size(),
+            static_cast<std::size_t>(1));
}
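
Net effect: the composite tanh backward now lowers to six grad ops instead of three — pow, fill_constant, elementwise_mul, fill_constant, elementwise_add, elementwise_mul — since each former scale op becomes a fill_constant plus an elementwise op, which is what the size check above counts.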
TEST(StaticCompositeGradMaker, TestMutiInputMethod) {
@@ -368,8 +387,10 @@ TEST(StaticPrim, TestFlags) {
} // namespace prim
} // namespace paddle
USE_OP_ITSELF(fill_constant);
USE_OP_ITSELF(tanh);
USE_OP_ITSELF(tanh_grad);
USE_OP_ITSELF(pow);
USE_OP_ITSELF(elementwise_mul);
USE_OP_ITSELF(elementwise_add);
USE_OP_ITSELF(scale);
@@ -548,6 +548,14 @@ class PADDLE_API Tensor final {
Tensor operator/(const Tensor& other) const;
Tensor operator+(const Scalar& other) const;
Tensor operator-(const Scalar& other) const;
Tensor operator*(const Scalar& other) const;
Tensor operator/(const Scalar& other) const;
/* Part 8: Autograd methods */
/**
@@ -663,6 +671,11 @@ class PADDLE_API Tensor final {
Tensor divide(const Tensor& y) const;
Tensor multiply(const Tensor& y) const;
Tensor subtract(const Tensor& y) const;
Tensor add(const Scalar& y) const;
Tensor divide(const Scalar& y) const;
Tensor multiply(const Scalar& y) const;
Tensor subtract(const Scalar& y) const;
Tensor exp() const;
Tensor floor() const;
Tensor gather_nd(const Tensor& index) const;
......
@@ -52,6 +52,14 @@ using IntArray = paddle::experimental::IntArray;
class TensorOperantsBase {
public:
virtual ~TensorOperantsBase() = default;
virtual Tensor add(const Tensor& x, const Scalar& y) = 0;
virtual Tensor divide(const Tensor& x, const Scalar& y) = 0;
virtual Tensor multiply(const Tensor& x, const Scalar& y) = 0;
virtual Tensor subtract(const Tensor& x, const Scalar& y) = 0;
"""
@@ -90,6 +98,38 @@ Tensor Tensor::operator*(const Tensor &other) const {
Tensor Tensor::operator/(const Tensor &other) const {
return divide(other);
}
Tensor Tensor::operator+(const Scalar &other) const {
return add(other);
}
Tensor Tensor::operator-(const Scalar &other) const {
return subtract(other);
}
Tensor Tensor::operator*(const Scalar &other) const {
return multiply(other);
}
Tensor Tensor::operator/(const Scalar &other) const {
return divide(other);
}
Tensor Tensor::add(const Scalar& y) const {
return paddle::OperantsManager::Instance().add(static_cast<const Tensor &>(*this), y);
}
Tensor Tensor::divide(const Scalar& y) const {
return paddle::OperantsManager::Instance().divide(static_cast<const Tensor &>(*this), y);
}
Tensor Tensor::multiply(const Scalar& y) const {
return paddle::OperantsManager::Instance().multiply(static_cast<const Tensor &>(*this), y);
}
Tensor Tensor::subtract(const Scalar& y) const {
return paddle::OperantsManager::Instance().subtract(static_cast<const Tensor &>(*this), y);
}
"""
@@ -126,6 +166,15 @@ class PhiTensorOperants : public TensorOperantsBase {
public:
PhiTensorOperants() = default;
Tensor add(const Tensor& x, const Scalar& y);
Tensor subtract(const Tensor& x, const Scalar& y);
Tensor multiply(const Tensor& x, const Scalar& y);
Tensor divide(const Tensor& x, const Scalar& y);
"""
@@ -150,6 +199,22 @@ operants_source_start = """
namespace paddle {
namespace operants {
Tensor PhiTensorOperants::add(const Tensor& x, const Scalar& y) {
return paddle::experimental::add(x, paddle::experimental::full_like(x, y));
}
Tensor PhiTensorOperants::subtract(const Tensor& x, const Scalar& y) {
return paddle::experimental::subtract(x, paddle::experimental::full_like(x, y));
}
Tensor PhiTensorOperants::multiply(const Tensor& x, const Scalar& y) {
return paddle::experimental::multiply(x, paddle::experimental::full_like(x, y));
}
Tensor PhiTensorOperants::divide(const Tensor& x, const Scalar& y) {
return paddle::experimental::divide(x, paddle::experimental::full_like(x, y));
}
"""
@@ -225,6 +290,15 @@ class OperantsManager {
public:
static OperantsManager& Instance();
Tensor add(const Tensor& x, const Scalar& y);
Tensor subtract(const Tensor& x, const Scalar& y);
Tensor multiply(const Tensor& x, const Scalar& y);
Tensor divide(const Tensor& x, const Scalar& y);
"""
@@ -395,17 +469,28 @@ class OperantsAPI(ForwardAPI):
    def gene_operants_manager_implementation(self):
        func_name = self.get_api_func_name()
+        final_code = ""
+        if func_name in ["add", "subtract", "multiply", "divide"]:
+            final_code += f"""
+{self.get_return_type()} OperantsManager::{func_name}(const Tensor& x, const Scalar& y) {{{self.gene_operants_manager_code()}}}
+"""
        # func declaration
        if func_name[-1] != '_':
-            return f"""
+            return (
+                final_code
+                + f"""
{self.get_return_type()} OperantsManager::{func_name}({self.get_define_args()}) {{{self.gene_operants_manager_code()}}}
"""
+            )
        else:
-            return f"""
+            return (
+                final_code
+                + f"""
{self.get_return_type(inplace_flag=True)} OperantsManager::{func_name}({self.get_define_args(inplace_flag=True)}) {{
  {self.gene_operants_manager_code()}
}}
"""
+            )
def generate_tensor_operants_api(
......
@@ -49,6 +49,36 @@ PD_BUILD_GRAD_OP(custom_add)
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(AddBackward));
// y = x + 1
std::vector<paddle::Tensor> ScalarAddForward(const paddle::Tensor& x) {
if (x.is_cpu() || x.is_gpu()) {
return {x + 1};
} else {
PD_THROW("Not implemented.");
}
}
// dy / dx = 1 * grad_out
std::vector<paddle::Tensor> ScalarAddBackward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
if (x.is_cpu() || x.is_gpu()) {
return {grad_out * 1};
} else {
PD_THROW("Not implemented.");
}
}
PD_BUILD_OP(custom_scalar_add)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ScalarAddForward));
PD_BUILD_GRAD_OP(custom_scalar_add)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(ScalarAddBackward));
// y = x - 1
std::vector<paddle::Tensor> SubtractForward(const paddle::Tensor& x) {
if (x.is_cpu() || x.is_gpu()) {
@@ -80,6 +110,37 @@ PD_BUILD_GRAD_OP(custom_subtract)
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(SubtractBackward));
// y = x - 1
std::vector<paddle::Tensor> ScalarSubtractForward(const paddle::Tensor& x) {
if (x.is_cpu() || x.is_gpu()) {
return {x - 1};
} else {
PD_THROW("Not implemented.");
}
}
// dy / dx = 1 * grad_out
std::vector<paddle::Tensor> ScalarSubtractBackward(
const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
if (x.is_cpu() || x.is_gpu()) {
return {grad_out * 1};
} else {
PD_THROW("Not implemented.");
}
}
PD_BUILD_OP(custom_scalar_subtract)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ScalarSubtractForward));
PD_BUILD_GRAD_OP(custom_scalar_subtract)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(ScalarSubtractBackward));
// y = x * 5
std::vector<paddle::Tensor> MultiplyForward(const paddle::Tensor& x) {
if (x.is_cpu() || x.is_gpu()) {
@@ -114,6 +175,37 @@ PD_BUILD_GRAD_OP(custom_multiply)
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(MultiplyBackward));
// y = x * 5
std::vector<paddle::Tensor> ScalarMultiplyForward(const paddle::Tensor& x) {
if (x.is_cpu() || x.is_gpu()) {
return {x * 5};
} else {
PD_THROW("Not implemented.");
}
}
// dy / dx = grad_out * 5
std::vector<paddle::Tensor> ScalarMultiplyBackward(
const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
if (x.is_cpu() || x.is_gpu()) {
return {grad_out * 5};
} else {
PD_THROW("Not implemented.");
}
}
PD_BUILD_OP(custom_scalar_multiply)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ScalarMultiplyForward));
PD_BUILD_GRAD_OP(custom_scalar_multiply)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(ScalarMultiplyBackward));
// y = 1 / x
std::vector<paddle::Tensor> DivideForward(const paddle::Tensor& x) {
if (x.is_cpu() || x.is_gpu()) {
@@ -145,3 +237,36 @@ PD_BUILD_GRAD_OP(custom_divide)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(DivideBackward));
// y = 1 / x / 1
std::vector<paddle::Tensor> ScalarDivideForward(const paddle::Tensor& x) {
if (x.is_cpu() || x.is_gpu()) {
paddle::Tensor ones = paddle::full(x.shape(), 1.0, x.dtype(), x.place());
return {ones / x / 1};
} else {
PD_THROW("Not implemented.");
}
}
// dy / dx = - (1 / x / x) * grad_out
std::vector<paddle::Tensor> ScalarDivideBackward(
const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
if (x.is_cpu() || x.is_gpu()) {
paddle::Tensor zeros = paddle::full(x.shape(), 0.0, x.dtype(), x.place());
return {zeros - grad_out / (x * x)};
} else {
PD_THROW("Not implemented.");
}
}
PD_BUILD_OP(custom_scalar_divide)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ScalarDivideForward));
PD_BUILD_GRAD_OP(custom_scalar_divide)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(ScalarDivideBackward));
@@ -228,6 +228,16 @@ class TestJITLoad(unittest.TestCase):
self.dtypes.append('float16')
def test_all(self):
self.add = self.custom_module.custom_add
self.subtract = self.custom_module.custom_subtract
self.multiply = self.custom_module.custom_multiply
self.divide = self.custom_module.custom_divide
self._test_static()
self._test_dynamic()
self.add = self.custom_module.custom_scalar_add
self.subtract = self.custom_module.custom_scalar_subtract
self.multiply = self.custom_module.custom_scalar_multiply
self.divide = self.custom_module.custom_scalar_divide
self._test_static()
self._test_dynamic()
@@ -238,35 +248,31 @@ class TestJITLoad(unittest.TestCase):
                    continue
                x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
-                out = test_custom_add_static(
-                    self.custom_module.custom_add, device, dtype, x
-                )
+                out = test_custom_add_static(self.add, device, dtype, x)
                pd_out = test_custom_add_static(
-                    self.custom_module.custom_add, device, dtype, x, False
+                    self.add, device, dtype, x, False
                )
                np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
                out = test_custom_subtract_static(
-                    self.custom_module.custom_subtract, device, dtype, x
+                    self.subtract, device, dtype, x
                )
                pd_out = test_custom_subtract_static(
-                    self.custom_module.custom_subtract, device, dtype, x, False
+                    self.subtract, device, dtype, x, False
                )
                np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
                out = test_custom_multiply_static(
-                    self.custom_module.custom_multiply, device, dtype, x
+                    self.multiply, device, dtype, x
                )
                pd_out = test_custom_multiply_static(
-                    self.custom_module.custom_multiply, device, dtype, x, False
+                    self.multiply, device, dtype, x, False
                )
                np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
-                out = test_custom_divide_static(
-                    self.custom_module.custom_divide, device, dtype, x
-                )
+                out = test_custom_divide_static(self.divide, device, dtype, x)
                pd_out = test_custom_divide_static(
-                    self.custom_module.custom_divide, device, dtype, x, False
+                    self.divide, device, dtype, x, False
                )
                np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
@@ -278,10 +284,10 @@ class TestJITLoad(unittest.TestCase):
                x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                out, x_grad = test_custom_add_dynamic(
-                    self.custom_module.custom_add, device, dtype, x
+                    self.add, device, dtype, x
                )
                pd_out, pd_x_grad = test_custom_add_dynamic(
-                    self.custom_module.custom_add, device, dtype, x, False
+                    self.add, device, dtype, x, False
                )
                np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
                np.testing.assert_allclose(
@@ -289,10 +295,10 @@ class TestJITLoad(unittest.TestCase):
                )
                out, x_grad = test_custom_subtract_dynamic(
-                    self.custom_module.custom_subtract, device, dtype, x
+                    self.subtract, device, dtype, x
                )
                pd_out, pd_x_grad = test_custom_subtract_dynamic(
-                    self.custom_module.custom_subtract, device, dtype, x, False
+                    self.subtract, device, dtype, x, False
                )
                np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
                np.testing.assert_allclose(
@@ -300,10 +306,10 @@ class TestJITLoad(unittest.TestCase):
                )
                out, x_grad = test_custom_multiply_dynamic(
-                    self.custom_module.custom_multiply, device, dtype, x
+                    self.multiply, device, dtype, x
                )
                pd_out, pd_x_grad = test_custom_multiply_dynamic(
-                    self.custom_module.custom_multiply, device, dtype, x, False
+                    self.multiply, device, dtype, x, False
                )
                np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
                np.testing.assert_allclose(
@@ -311,10 +317,10 @@ class TestJITLoad(unittest.TestCase):
                )
                out, x_grad = test_custom_divide_dynamic(
-                    self.custom_module.custom_divide, device, dtype, x
+                    self.divide, device, dtype, x
                )
                pd_out, pd_x_grad = test_custom_divide_dynamic(
-                    self.custom_module.custom_divide, device, dtype, x, False
+                    self.divide, device, dtype, x, False
                )
                np.testing.assert_allclose(out, pd_out, rtol=1e-5, atol=1e-8)
......
@@ -41,7 +41,14 @@ from paddle.fluid import core, framework
            {'Out': ['y']},
            set(),
            tuple(),
-            ('pow', 'scale', 'elementwise_mul'),
+            (
+                'pow',
+                'fill_constant',
+                'elementwise_mul',
+                'fill_constant',
+                'elementwise_add',
+                'elementwise_mul',
+            ),
),
('empty', {}, {'Out': ['y']}, set(), tuple(), tuple()),
),
......
@@ -27,7 +27,14 @@ class TestGetGradOpDescPrimEnabled(unittest.TestCase):
        self.no_grad_var = set()
        self.grad_sub_block = tuple()
        self.desired_ops = 'tanh_grad'
-        self.desired_ops_no_skip = ('pow', 'scale', 'elementwise_mul')
+        self.desired_ops_no_skip = (
+            'pow',
+            'fill_constant',
+            'elementwise_mul',
+            'fill_constant',
+            'elementwise_add',
+            'elementwise_mul',
+        )
paddle.enable_static()
block = framework.Block(framework.Program(), 0)
block.append_op(
......