Unverified commit 98303291, authored by H hong, committed by GitHub

Add basic yaml backward (#40751)

* fix error; test=develop

* update

* close some yaml

* fix backward attribute error; test=develop

* add div test

* polish code; test=develop

* update

* update

* fix bug

* update bitwise code; test=develop

* update

* update

* fix some bug

* update

* revert cmakelist

* fix optional bug;

* fix bug

* fix bug;

* add backward test

* open bn

* update

* update

* revert eager_gen

* polish code

* fix topk error

* update

* update

* fix bug;

* move label smooth, nll loss

* revert topk

* fix topk label smooth bug;

* remove batch_norm

* remove topk

* change flip infer meta

* fix flip bug

* update yaml

* close abs

* fix histogram bug

* fix histogram bug

* add abs

* fix histogram kernel

* remove expand
Parent 2f41f389
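The recurring pattern in the hunks below: each migrated grad kernel now takes the forward inputs (x, index, mask, ...) first and out_grad last, the argument mappings are reordered to match, and the Python APIs gain an in_dygraph_mode() fast path that dispatches to the generated final_state_* ops. A minimal sketch of exercising one migrated op end to end (a sketch only, assuming a build with eager mode and _test_eager_guard available; paddle.divide is among the ops touched by this commit):

    import numpy as np
    import paddle
    from paddle.fluid.framework import _test_eager_guard

    x_np = np.random.uniform(0.1, 1.0, [13, 17]).astype("float32")
    y_np = np.random.uniform(0.1, 1.0, [13, 17]).astype("float32")

    with _test_eager_guard():  # routes calls through the final_state_* eager ops
        x = paddle.to_tensor(x_np, stop_gradient=False)
        y = paddle.to_tensor(y_np)
        out = paddle.divide(x, y)  # forward op described in the forward yaml
        out.sum().backward()       # backward op resolved from the new backward yaml
        assert np.allclose(x.grad.numpy(), 1.0 / y_np)  # d(sum(x/y))/dx = 1/y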
@@ -68,9 +68,9 @@ void IndexSampleGradInner(const Context& context,
template <typename T, typename Context>
void IndexSampleGradKernel(const Context& ctx,
-                           const DenseTensor& out_grad,
                            const DenseTensor& x,
                            const DenseTensor& index,
+                           const DenseTensor& out_grad,
                            DenseTensor* x_grad) {
  auto index_type = index.dtype();
  bool index_type_match =
......
@@ -21,9 +21,9 @@ namespace phi {
template <typename T, typename Context>
void MaskedSelectGradKernel(const Context& dev_ctx,
-                            const DenseTensor& out_grad,
                             const DenseTensor& x,
                             const DenseTensor& mask,
+                            const DenseTensor& out_grad,
                             DenseTensor* x_grad) {
  auto* mask_data = mask.data<bool>();
  auto* input_data = out_grad.data<T>();
......
@@ -121,8 +121,8 @@ template <typename T, typename Context>
void NllLossGradKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& labels,
-                      const DenseTensor& total_weight,
                       paddle::optional<const DenseTensor&> weight,
+                      const DenseTensor& total_weight,
                       const DenseTensor& d_out,
                       int64_t ignore_index,
                       const std::string& reduction,
......
@@ -51,17 +51,17 @@ static void FullTopKAssign(const Type& input_height,
template <typename T, typename Context>
void TopkGradKernel(const Context& dev_ctx,
-                    const DenseTensor& out_grad,
                     const DenseTensor& x,
                     const DenseTensor& indices,
-                    int k,
+                    const DenseTensor& out_grad,
+                    const Scalar& k_scalar,
                     int axis,
                     bool largest,
                     bool sorted,
                     DenseTensor* x_grad) {
  const auto& in_dims = x.dims();
  const auto& out_dims = indices.dims();
+  int k = k_scalar.to<int>();
  // axis < 0, get the real axis
  axis = (axis < 0) ? (in_dims.size() + axis) : axis;
......
@@ -36,7 +36,7 @@ void LimitGridDim(const Context& ctx, dim3* grid_dim) {
#define PREDEFINED_BLOCK_SIZE_X 512
#define PREDEFINED_BLOCK_SIZE 1024
#define MIN(a, b) ((a) < (b) ? (a) : (b))
-};
+}  // namespace

template <typename T, typename IndexT = int>
__global__ void IndexSampleGrad(const IndexT* index,
@@ -67,9 +67,9 @@ __global__ void IndexSampleGrad(const IndexT* index,
template <typename T, typename Context>
void IndexSampleGradKernel(const Context& ctx,
-                           const DenseTensor& out_grad,
                            const DenseTensor& x,
                            const DenseTensor& index,
+                           const DenseTensor& out_grad,
                            DenseTensor* x_grad) {
  const T* output_grad_data = out_grad.data<T>();
  T* input_grad_data = ctx.template Alloc<T>(x_grad);
......
@@ -44,9 +44,9 @@ struct MaskedSelectGradFunctor {
template <typename T, typename Context>
void MaskedSelectGradKernel(const Context& dev_ctx,
-                            const DenseTensor& out_grad,
                             const DenseTensor& x,
                             const DenseTensor& mask,
+                            const DenseTensor& out_grad,
                             DenseTensor* x_grad) {
  auto mask_size = mask.numel();
  dev_ctx.template Alloc<T>(x_grad);
......
@@ -23,8 +23,8 @@ template <typename T, typename Context>
void NllLossGradKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& labels,
-                      const DenseTensor& total_weight,
                       paddle::optional<const DenseTensor&> weight,
+                      const DenseTensor& total_weight,
                       const DenseTensor& dout,
                       int64_t ignore_index,
                       const std::string& reduction,
......
@@ -25,10 +25,10 @@ namespace ops = paddle::operators;
template <typename T, typename Context>
void TopkGradKernel(const Context& dev_ctx,
-                    const DenseTensor& out_grad,
                     const DenseTensor& x,
                     const DenseTensor& indices,
-                    int k,
+                    const DenseTensor& out_grad,
+                    const Scalar& k_scalar,
                     int axis,
                     bool largest,
                     bool sorted,
@@ -36,6 +36,8 @@ void TopkGradKernel(const Context& dev_ctx,
  const auto& in_dims = x.dims();
  const auto& out_dims = indices.dims();
+  int k = k_scalar.to<int>();
+  // get the real axis and the k
  if (axis < 0) {
    axis += in_dims.size();
......
@@ -18,11 +18,11 @@
namespace phi {

template <typename T, typename Context>
-void HistogramSelectKernel(const Context& dev_ctx,
-                           const DenseTensor& input,
-                           int64_t bins,
-                           int min,
-                           int max,
-                           DenseTensor* out);
+void HistogramKernel(const Context& dev_ctx,
+                     const DenseTensor& input,
+                     int64_t bins,
+                     int min,
+                     int max,
+                     DenseTensor* output);
} // namespace phi
@@ -20,9 +20,9 @@ namespace phi {
template <typename T, typename Context>
void IndexSampleGradKernel(const Context& ctx,
-                           const DenseTensor& out_grad,
                            const DenseTensor& x,
                            const DenseTensor& index,
+                           const DenseTensor& out_grad,
                            DenseTensor* in_grad);
} // namespace phi
@@ -19,9 +19,9 @@ namespace phi {
template <typename T, typename Context>
void MaskedSelectGradKernel(const Context& dev_ctx,
-                            const DenseTensor& out_grad,
                             const DenseTensor& x,
                             const DenseTensor& mask,
+                            const DenseTensor& out_grad,
                             DenseTensor* x_grad);
}  // namespace phi
@@ -22,8 +22,8 @@ template <typename T, typename Context>
void NllLossGradKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& label,
-                      const DenseTensor& total_weight,
                       paddle::optional<const DenseTensor&> weight,
+                      const DenseTensor& total_weight,
                       const DenseTensor& d_out,
                       int64_t ignore_index,
                       const std::string& reduction,
......
@@ -14,16 +14,17 @@

#pragma once

+#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void TopkGradKernel(const Context& dev_ctx,
-                    const DenseTensor& out_grad,
                     const DenseTensor& x,
                     const DenseTensor& indices,
-                    int k,
+                    const DenseTensor& out_grad,
+                    const Scalar& k,
                     int axis,
                     bool largest,
                     bool sorted,
......
@@ -19,7 +19,7 @@ namespace phi {
KernelSignature IndexSampleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("index_sample_grad",
-                        {GradVarName("Out"), "X", "Index"},
+                        {"X", "Index", GradVarName("Out")},
                         {},
                         {GradVarName("X")});
}
......
@@ -24,7 +24,7 @@ KernelSignature MaskedSelectOpArgumentMapping(
KernelSignature MaskedSelectGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("masked_select_grad",
-                        {GradVarName("Y"), "X", "Mask"},
+                        {"X", "Mask", GradVarName("Y")},
                         {},
                         {GradVarName("X")});
}
......
@@ -29,7 +29,7 @@ KernelSignature NllLossGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "nll_loss_grad",
-      {"X", "Label", "Total_weight", "Weight", GradVarName("Out")},
+      {"X", "Label", "Weight", "Total_weight", GradVarName("Out")},
      {"ignore_index", "reduction"},
      {GradVarName("X")});
}
......
@@ -29,7 +29,7 @@ KernelSignature TopkOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature TopkGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("top_k_grad",
-                        {GradVarName("Out"), "X", "Indices"},
+                        {"X", "Indices", GradVarName("Out")},
                         {"k", "axis", "largest", "sorted"},
                         {GradVarName("X")});
}
......
@@ -12529,6 +12529,9 @@ def logical_and(x, y, out=None, name=None):
            res = paddle.logical_and(x, y)
            print(res) # [True False True False]
    """
+    if in_dygraph_mode():
+        return _C_ops.final_state_logical_and(x, y)
+
    return _logical_op(
        op_name="logical_and", x=x, y=y, name=name, out=out, binary_op=True)
@@ -12568,6 +12571,8 @@ def logical_or(x, y, out=None, name=None):
            res = paddle.logical_or(x, y)
            print(res) # [[ True True] [ True False]]
    """
+    if in_dygraph_mode():
+        return _C_ops.final_state_logical_or(x, y)
    return _logical_op(
        op_name="logical_or", x=x, y=y, name=name, out=out, binary_op=True)
@@ -12607,6 +12612,9 @@ def logical_xor(x, y, out=None, name=None):
            res = paddle.logical_xor(x, y)
            print(res) # [[False, True], [ True, False]]
    """
+    if in_dygraph_mode():
+        return _C_ops.final_state_logical_xor(x, y)
+
    return _logical_op(
        op_name="logical_xor", x=x, y=y, name=name, out=out, binary_op=True)
@@ -12639,7 +12647,8 @@ def logical_not(x, out=None, name=None):
            res = paddle.logical_not(x)
            print(res) # [False True False True]
    """
+    if in_dygraph_mode():
+        return _C_ops.final_state_logical_not(x)
    return _logical_op(
        op_name="logical_not", x=x, y=None, name=name, out=out, binary_op=False)
......
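These fast paths change only the dispatch, not the semantics, so eager results should match NumPy exactly; a quick sanity sketch (assuming an eager-enabled build; the updated unit test further below automates the same comparison):

    import numpy as np
    import paddle
    from paddle.fluid.framework import _test_eager_guard

    x = np.array([True, False, True, False])
    y = np.array([True, True, False, False])
    with _test_eager_guard():
        res = paddle.logical_and(paddle.to_tensor(x), paddle.to_tensor(y))
        assert (res.numpy() == np.logical_and(x, y)).all()
        res = paddle.logical_not(paddle.to_tensor(x))
        assert (res.numpy() == np.logical_not(x)).all()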
@@ -19,7 +19,7 @@ import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from op_test import OpTest, _set_use_system_allocator
-from paddle.fluid.framework import grad_var_name
+from paddle.fluid.framework import grad_var_name, _test_eager_guard
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import paddle
@@ -46,32 +46,32 @@ class TestBatchNorm(unittest.TestCase):
        def error1d_dataformat():
            x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32')
            batch_norm1d = paddle.nn.BatchNorm1D(1, data_format='NCDHW')
-            batch_norm1d(fluid.dygraph.to_variable(x_data_4))
+            batch_norm1d(paddle.to_tensor(x_data_4))

        def error2d_dataformat():
            x_data_3 = np.random.random(size=(2, 1, 3)).astype('float32')
            batch_norm2d = paddle.nn.BatchNorm2D(1, data_format='NCDHW')
-            batch_norm2d(fluid.dygraph.to_variable(x_data_3))
+            batch_norm2d(paddle.to_tensor(x_data_3))

        def error3d_dataformat():
            x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32')
            batch_norm3d = paddle.nn.BatchNorm3D(1, data_format='NCL')
-            batch_norm3d(fluid.dygraph.to_variable(x_data_4))
+            batch_norm3d(paddle.to_tensor(x_data_4))

        def error1d():
            x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32')
            batch_norm1d = paddle.nn.BatchNorm1D(1)
-            batch_norm1d(fluid.dygraph.to_variable(x_data_4))
+            batch_norm1d(paddle.to_tensor(x_data_4))

        def error2d():
            x_data_3 = np.random.random(size=(2, 1, 3)).astype('float32')
            batch_norm2d = paddle.nn.BatchNorm2D(1)
-            batch_norm2d(fluid.dygraph.to_variable(x_data_3))
+            batch_norm2d(paddle.to_tensor(x_data_3))

        def error3d():
            x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32')
            batch_norm3d = paddle.nn.BatchNorm3D(1)
-            batch_norm3d(fluid.dygraph.to_variable(x_data_4))
+            batch_norm3d(paddle.to_tensor(x_data_4))

        with fluid.dygraph.guard(p):
            self.assertRaises(ValueError, error1d)
@@ -94,13 +94,18 @@ class TestBatchNorm(unittest.TestCase):
                    shape[1],
                    is_test=is_test,
                    trainable_statistics=trainable_statistics)
-                y = bn(fluid.dygraph.to_variable(x))
+                y = bn(paddle.to_tensor(x))
            return y.numpy()

        def compute_v2(x):
            with fluid.dygraph.guard(p):
                bn = paddle.nn.BatchNorm2D(shape[1])
-                y = bn(fluid.dygraph.to_variable(x))
+                y = bn(paddle.to_tensor(x))
+
+                with _test_eager_guard():
+                    bn = paddle.nn.BatchNorm2D(shape[1])
+                    eag_y = bn(paddle.to_tensor(x))
+                    assert np.allclose(eag_y.numpy(), y.numpy())
            return y.numpy()

        def compute_v3(x, is_test, trainable_statistics):
@@ -115,14 +120,14 @@ class TestBatchNorm(unittest.TestCase):
                    initializer=fluid.initializer.Constant(0.0),
                    trainable=False),
                trainable_statistics=trainable_statistics)
-            y = bn(fluid.dygraph.to_variable(x))
+            y = bn(paddle.to_tensor(x))
            return y.numpy()

        def compute_v4(x):
            with fluid.dygraph.guard(p):
                bn = paddle.nn.BatchNorm2D(
                    shape[1], weight_attr=False, bias_attr=False)
-                y = bn(fluid.dygraph.to_variable(x))
+                y = bn(paddle.to_tensor(x))
            return y.numpy()

        x = np.random.randn(*shape).astype("float32")
......
@@ -32,6 +32,7 @@ class ElementwiseDivOp(OpTest):
            'X': np.random.random((32,84)).astype("float32"),
            'Y': np.random.random((32,84)).astype("float32")
        """
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype),
            'Y': np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
@@ -39,7 +40,7 @@ class ElementwiseDivOp(OpTest):
        self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}

    def check_eager(self):
-        return (self.use_mkldnn == False and self.axis == -1)
+        return (not hasattr(self, "attrs") or (self.attrs["axis"] != -1))

    def test_check_output(self):
        self.check_output(check_eager=False)
@@ -65,6 +66,7 @@ class ElementwiseDivOp(OpTest):
class TestElementwiseDivOpBF16(OpTest):
    def setUp(self):
        self.op_type = "elementwise_div"
+        self.python_api = paddle.divide
        self.dtype = np.uint16
        x = np.random.uniform(0.1, 1, [12, 13]).astype(np.float32)
@@ -100,6 +102,7 @@ class TestElementwiseDivOpBF16(OpTest):
class TestElementwiseDivOp_scalar(ElementwiseDivOp):
    def setUp(self):
        self.op_type = "elementwise_div"
+        self.python_api = paddle.divide
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [20, 3, 4]).astype(np.float64),
            'Y': np.random.uniform(0.1, 1, [1]).astype(np.float64)
@@ -110,6 +113,7 @@ class TestElementwiseDivOp_scalar(ElementwiseDivOp):
class TestElementwiseDivOp_Vector(ElementwiseDivOp):
    def setUp(self):
        self.op_type = "elementwise_div"
+        self.python_api = paddle.divide
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [100]).astype("float64"),
            'Y': np.random.uniform(0.1, 1, [100]).astype("float64")
@@ -120,6 +124,7 @@ class TestElementwiseDivOp_Vector(ElementwiseDivOp):
class TestElementwiseDivOp_broadcast_0(ElementwiseDivOp):
    def setUp(self):
        self.op_type = "elementwise_div"
+        self.python_api = paddle.divide
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [100, 3, 4]).astype("float64"),
            'Y': np.random.uniform(0.1, 1, [100]).astype("float64")
@@ -135,6 +140,7 @@ class TestElementwiseDivOp_broadcast_0(ElementwiseDivOp):
class TestElementwiseDivOp_broadcast_1(ElementwiseDivOp):
    def setUp(self):
        self.op_type = "elementwise_div"
+        self.python_api = paddle.divide
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [2, 100, 4]).astype("float64"),
            'Y': np.random.uniform(0.1, 1, [100]).astype("float64")
@@ -150,6 +156,7 @@ class TestElementwiseDivOp_broadcast_1(ElementwiseDivOp):
class TestElementwiseDivOp_broadcast_2(ElementwiseDivOp):
    def setUp(self):
        self.op_type = "elementwise_div"
+        self.python_api = paddle.divide
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [2, 3, 100]).astype("float64"),
            'Y': np.random.uniform(0.1, 1, [100]).astype("float64")
@@ -164,6 +171,7 @@ class TestElementwiseDivOp_broadcast_2(ElementwiseDivOp):
class TestElementwiseDivOp_broadcast_3(ElementwiseDivOp):
    def setUp(self):
        self.op_type = "elementwise_div"
+        self.python_api = paddle.divide
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [2, 10, 12, 5]).astype("float64"),
            'Y': np.random.uniform(0.1, 1, [10, 12]).astype("float64")
@@ -179,6 +187,7 @@ class TestElementwiseDivOp_broadcast_3(ElementwiseDivOp):
class TestElementwiseDivOp_broadcast_4(ElementwiseDivOp):
    def setUp(self):
        self.op_type = "elementwise_div"
+        self.python_api = paddle.divide
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [2, 3, 50]).astype("float64"),
            'Y': np.random.uniform(0.1, 1, [2, 1, 50]).astype("float64")
@@ -189,6 +198,7 @@ class TestElementwiseDivOp_broadcast_4(ElementwiseDivOp):
class TestElementwiseDivOp_broadcast_5(ElementwiseDivOp):
    def setUp(self):
        self.op_type = "elementwise_div"
+        self.python_api = paddle.divide
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [2, 3, 4, 20]).astype("float64"),
            'Y': np.random.uniform(0.1, 1, [2, 3, 1, 20]).astype("float64")
@@ -199,6 +209,7 @@ class TestElementwiseDivOp_broadcast_5(ElementwiseDivOp):
class TestElementwiseDivOp_commonuse_1(ElementwiseDivOp):
    def setUp(self):
        self.op_type = "elementwise_div"
+        self.python_api = paddle.divide
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [2, 3, 100]).astype("float64"),
            'Y': np.random.uniform(0.1, 1, [1, 1, 100]).astype("float64"),
@@ -209,6 +220,7 @@ class TestElementwiseDivOp_commonuse_1(ElementwiseDivOp):
class TestElementwiseDivOp_commonuse_2(ElementwiseDivOp):
    def setUp(self):
        self.op_type = "elementwise_div"
+        self.python_api = paddle.divide
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [30, 3, 1, 5]).astype("float64"),
            'Y': np.random.uniform(0.1, 1, [30, 1, 4, 1]).astype("float64"),
@@ -219,6 +231,7 @@ class TestElementwiseDivOp_commonuse_2(ElementwiseDivOp):
class TestElementwiseDivOp_xsize_lessthan_ysize(ElementwiseDivOp):
    def setUp(self):
        self.op_type = "elementwise_div"
+        self.python_api = paddle.divide
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [10, 12]).astype("float64"),
            'Y': np.random.uniform(0.1, 1, [2, 3, 10, 12]).astype("float64"),
@@ -232,6 +245,7 @@ class TestElementwiseDivOp_xsize_lessthan_ysize(ElementwiseDivOp):
class TestElementwiseDivOp_INT(OpTest):
    def setUp(self):
        self.op_type = "elementwise_div"
+        self.python_api = paddle.divide
        self.dtype = np.int32
        self.init_dtype()
        self.inputs = {
@@ -304,6 +318,7 @@ class TestDivideOp(unittest.TestCase):
class TestComplexElementwiseDivOp(OpTest):
    def setUp(self):
        self.op_type = "elementwise_div"
+        self.python_api = paddle.divide
        self.init_base_dtype()
        self.init_input_output()
        self.init_grad_input_output()
@@ -334,7 +349,7 @@ class TestComplexElementwiseDivOp(OpTest):
        self.grad_y = -self.grad_out * np.conj(self.x / self.y / self.y)

    def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=False)

    def test_check_grad_normal(self):
        self.check_grad(
......
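The python_api attribute added throughout this file tells OpTest which public Python function corresponds to the operator, so check_eager can replay the same case through the eager final-state path (divide itself keeps check_eager=False in this commit). A schematic of the pattern, not part of the commit, with check_eager=True shown for an op whose eager path is enabled:

    import numpy as np
    import paddle
    from op_test import OpTest

    class TestDivideEagerSketch(OpTest):  # illustrative only
        def setUp(self):
            self.op_type = "elementwise_div"
            self.python_api = paddle.divide  # eager entry point for the op
            x = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
            y = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
            self.inputs = {'X': x, 'Y': y}
            self.outputs = {'Out': np.divide(x, y)}

        def test_check_output(self):
            self.check_output(check_eager=True)  # also runs the final-state kernel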
@@ -24,6 +24,7 @@ import paddle.fluid as fluid
class TestExpandAsOpRank1(OpTest):
    def setUp(self):
        self.op_type = "expand_as_v2"
+        self.python_api = paddle.expand_as
        x = np.random.rand(100).astype("float64")
        target_tensor = np.random.rand(2, 100).astype("float64")
        self.inputs = {'X': x}
......
@@ -21,6 +21,7 @@ import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from op_test import OpTest
+from paddle.fluid.framework import _test_eager_guard


class TestHistogramOpAPI(unittest.TestCase):
@@ -57,6 +58,15 @@ class TestHistogramOpAPI(unittest.TestCase):
                (actual.numpy() == expected).all(),
                msg='histogram output is wrong, out =' + str(actual.numpy()))

+            with _test_eager_guard():
+                inputs_np = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int64)
+                inputs = paddle.to_tensor(inputs_np)
+                actual = paddle.histogram(inputs, bins=5, min=1, max=5)
+                self.assertTrue(
+                    (actual.numpy() == expected).all(),
+                    msg='histogram output is wrong, out =' +
+                    str(actual.numpy()))
+

class TestHistogramOpError(unittest.TestCase):
    """Test histogram op error."""
@@ -118,6 +128,7 @@ class TestHistogramOp(OpTest):
        self.op_type = "histogram"
        self.init_test_case()
        np_input = np.random.uniform(low=0.0, high=20.0, size=self.in_shape)
+        self.python_api = paddle.histogram
        self.inputs = {"X": np_input}
        self.init_attrs()
        Out, _ = np.histogram(
@@ -134,7 +145,7 @@ class TestHistogramOp(OpTest):
        self.attrs = {"bins": self.bins, "min": self.min, "max": self.max}

    def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)


if __name__ == "__main__":
......
@@ -40,10 +40,10 @@ class TestIndexSampleOp(OpTest):
        self.outputs = {'Out': out}

    def test_check_output(self):
-        self.check_output(check_eager=False)
+        self.check_output(check_eager=True)

    def test_check_grad(self):
-        self.check_grad(['X'], 'Out', check_eager=False)
+        self.check_grad(['X'], 'Out', check_eager=True)

    def config(self):
        """
......
@@ -16,6 +16,7 @@ import paddle
import paddle.fluid as fluid
import unittest
import numpy as np
+from paddle.fluid.framework import _test_eager_guard


def run_static(x_np, dtype, op_str, use_gpu=False):
@@ -46,6 +47,18 @@ def run_dygraph(x_np, op_str, use_gpu=True):
    return dygraph_result


+def run_eager(x_np, op_str, use_gpu=True):
+    with paddle.fluid.dygraph.guard():
+        with _test_eager_guard():
+            place = paddle.CPUPlace()
+            if use_gpu and fluid.core.is_compiled_with_cuda():
+                place = paddle.CUDAPlace(0)
+
+            x = paddle.to_tensor(x_np)
+            dygraph_result = getattr(paddle.tensor, op_str)(x)
+            return dygraph_result
+
+
def np_data_generator(low, high, np_shape, type, sv_list, op_str, *args,
                      **kwargs):
    x_np = np.random.uniform(low, high, np_shape).astype(getattr(np, type))
@@ -107,8 +120,10 @@ def test(test_case, op_str, use_gpu=False):
    x_np, result_np = np_data_generator(**meta_data)
    static_result = run_static(x_np, meta_data['type'], op_str, use_gpu)
    dygraph_result = run_dygraph(x_np, op_str, use_gpu)
+    eager_result = run_eager(x_np, op_str, use_gpu)
    test_case.assertTrue((static_result == result_np).all())
    test_case.assertTrue((dygraph_result.numpy() == result_np).all())
+    test_case.assertTrue((eager_result.numpy() == result_np).all())


class TestCPUNormal(unittest.TestCase):
@@ -158,4 +173,5 @@ class TestError(unittest.TestCase):

if __name__ == '__main__':
+    paddle.enable_static()
    unittest.main()
@@ -27,6 +27,7 @@ np.random.seed(0)
class TestLerp(OpTest):
    def setUp(self):
        self.op_type = "lerp"
+        self.python_api = paddle.lerp
        self.init_dtype()
        self.init_shape()
        x = np.arange(1., 101.).astype(self.dtype).reshape(self.shape)
@@ -42,10 +43,10 @@ class TestLerp(OpTest):
        self.shape = [100]

    def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)

    def test_check_grad(self):
-        self.check_grad(['X', 'Y'], 'Out')
+        self.check_grad(['X', 'Y'], 'Out', check_eager=True)


class TestLerpWithDim2(TestLerp):
......
@@ -20,6 +20,7 @@ import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.static import Program, program_guard
+from paddle.fluid.framework import _test_eager_guard

SUPPORTED_DTYPES = [
    bool, np.int8, np.int16, np.int32, np.int64, np.float32, np.float64
@@ -144,6 +145,22 @@ def run_dygraph(x_np, y_np, op_str, use_gpu=False, binary_op=True):
    return dygraph_result


+def run_eager(x_np, y_np, op_str, use_gpu=False, binary_op=True):
+    place = paddle.CPUPlace()
+    if use_gpu and fluid.core.is_compiled_with_cuda():
+        place = paddle.CUDAPlace(0)
+    paddle.disable_static(place)
+    with _test_eager_guard():
+        op = getattr(paddle, op_str)
+        x = paddle.to_tensor(x_np, dtype=x_np.dtype)
+        if not binary_op:
+            dygraph_result = op(x)
+        else:
+            y = paddle.to_tensor(y_np, dtype=y_np.dtype)
+            dygraph_result = op(x, y)
+        return dygraph_result
+
+
def np_data_generator(np_shape, dtype, *args, **kwargs):
    if dtype == bool:
        return np.random.choice(a=[True, False], size=np_shape).astype(bool)
@@ -174,6 +191,7 @@ def test(unit_test, use_gpu=False, test_error=False):
                continue
            static_result = run_static(**meta_data)
            dygraph_result = run_dygraph(**meta_data)
+            eager_result = run_eager(**meta_data)
            if meta_data['binary_op']:
                np_result = np_op(meta_data['x_np'], meta_data['y_np'])
            else:
@@ -181,6 +199,7 @@ def test(unit_test, use_gpu=False, test_error=False):
            unit_test.assertTrue((static_result == np_result).all())
            unit_test.assertTrue((dygraph_result.numpy() == np_result).all(
            ))
+            unit_test.assertTrue((eager_result.numpy() == np_result).all())


def test_type_error(unit_test, use_gpu, type_str_map):
@@ -259,4 +278,5 @@ class TestCUDA(unittest.TestCase):

if __name__ == '__main__':
+    paddle.enable_static()
    unittest.main()
@@ -33,6 +33,7 @@ class TestMaskedSelectOp(OpTest):
    def setUp(self):
        self.init()
        self.op_type = "masked_select"
+        self.python_api = paddle.masked_select
        x = np.random.random(self.shape).astype("float64")
        mask = np.array(np.random.randint(2, size=self.shape, dtype=bool))
        out = np_masked_select(x, mask)
@@ -40,10 +41,10 @@ class TestMaskedSelectOp(OpTest):
        self.outputs = {'Y': out}

    def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)

    def test_check_grad(self):
-        self.check_grad(['X'], 'Y')
+        self.check_grad(['X'], 'Y', check_eager=True)

    def init(self):
        self.shape = (50, 3)
@@ -121,4 +122,5 @@ class TestMaskedSelectError(unittest.TestCase):

if __name__ == '__main__':
+    paddle.enable_static()
    unittest.main()
@@ -17,6 +17,7 @@ import paddle.fluid as fluid
import numpy as np
import unittest
from op_test import OpTest
+from paddle.fluid.framework import _test_eager_guard


def nll_loss_1d(logs, targets, weight=None, reduction='mean',
@@ -97,14 +98,21 @@ class TestNLLLoss(unittest.TestCase):
        with fluid.dygraph.guard():
            nll_loss = paddle.nn.loss.NLLLoss()
            dy_res = nll_loss(
-                fluid.dygraph.to_variable(input_np),
-                fluid.dygraph.to_variable(label_np))
+                paddle.to_tensor(input_np), paddle.to_tensor(label_np))
            dy_result = dy_res.numpy()

+        with fluid.dygraph.guard():
+            with _test_eager_guard():
+                nll_loss = paddle.nn.loss.NLLLoss()
+                eager_res = nll_loss(
+                    paddle.to_tensor(input_np), paddle.to_tensor(label_np))
+                eager_result = eager_res.numpy()
+
        expected = nll_loss_1d(input_np, label_np)[0]
        self.assertTrue(np.allclose(static_result, expected))
        self.assertTrue(np.allclose(static_result, dy_result))
        self.assertTrue(np.allclose(dy_result, expected))
+        self.assertTrue(np.allclose(eager_result, expected))

    def test_NLLLoss_1D_sum(self):
        np.random.seed(200)
@@ -132,14 +140,24 @@ class TestNLLLoss(unittest.TestCase):
        with fluid.dygraph.guard():
            nll_loss = paddle.nn.loss.NLLLoss(reduction='sum')
            dy_res = nll_loss(
-                fluid.dygraph.to_variable(input_np),
-                fluid.dygraph.to_variable(label_np))
+                paddle.to_tensor(input_np), paddle.to_tensor(label_np))
            dy_result = dy_res.numpy()

+            with _test_eager_guard():
+                nll_loss = paddle.nn.loss.NLLLoss(reduction='sum')
+                in_t = paddle.to_tensor(input_np)
+                label = paddle.to_tensor(label_np)
+                in_t.stop_gradient = False
+                eager_res = nll_loss(in_t, label)
+                eager_result = eager_res.numpy()
+                loss = eager_res.sum()
+                loss.backward()
+
        expected = nll_loss_1d(input_np, label_np, reduction='sum')[0]
        self.assertTrue(np.allclose(static_result, expected))
        self.assertTrue(np.allclose(static_result, dy_result))
        self.assertTrue(np.allclose(dy_result, expected))
+        self.assertTrue(np.allclose(eager_result, expected))

    def test_NLLLoss_1D_with_weight_mean(self):
        np.random.seed(200)
@@ -170,16 +188,26 @@ class TestNLLLoss(unittest.TestCase):
        with fluid.dygraph.guard():
            nll_loss = paddle.nn.loss.NLLLoss(
-                weight=fluid.dygraph.to_variable(weight_np))
+                weight=paddle.to_tensor(weight_np))
            dy_res = nll_loss(
-                fluid.dygraph.to_variable(input_np),
-                fluid.dygraph.to_variable(label_np))
+                paddle.to_tensor(input_np), paddle.to_tensor(label_np))
            dy_result = dy_res.numpy()

+            with _test_eager_guard():
+                nll_loss = paddle.nn.loss.NLLLoss(
+                    weight=paddle.to_tensor(weight_np))
+                eager_res = nll_loss(
+                    paddle.to_tensor(input_np), paddle.to_tensor(label_np))
+                loss = eager_res.sum()
+                loss.backward()
+                eager_result = eager_res.numpy()
+
        expected = nll_loss_1d(input_np, label_np, weight=weight_np)[0]
        self.assertTrue(np.allclose(static_result, expected))
        self.assertTrue(np.allclose(static_result, dy_result))
        self.assertTrue(np.allclose(dy_result, expected))
+        self.assertTrue(np.allclose(eager_result, expected))

    def test_NLLLoss_1D_with_weight_sum(self):
        np.random.seed(200)
@@ -210,10 +238,9 @@ class TestNLLLoss(unittest.TestCase):
        with fluid.dygraph.guard():
            nll_loss = paddle.nn.loss.NLLLoss(
-                weight=fluid.dygraph.to_variable(weight_np), reduction='sum')
+                weight=paddle.to_tensor(weight_np), reduction='sum')
            dy_res = nll_loss(
-                fluid.dygraph.to_variable(input_np),
-                fluid.dygraph.to_variable(label_np))
+                paddle.to_tensor(input_np), paddle.to_tensor(label_np))
            dy_result = dy_res.numpy()

        expected = nll_loss_1d(
            input_np, label_np, weight=weight_np, reduction='sum')[0]
@@ -249,10 +276,9 @@ class TestNLLLoss(unittest.TestCase):
        with fluid.dygraph.guard():
            nll_loss = paddle.nn.loss.NLLLoss(
-                weight=fluid.dygraph.to_variable(weight_np))
+                weight=paddle.to_tensor(weight_np))
            dy_res = nll_loss(
-                fluid.dygraph.to_variable(input_np),
-                fluid.dygraph.to_variable(label_np))
+                paddle.to_tensor(input_np), paddle.to_tensor(label_np))
            dy_result = dy_res.numpy()

        expected = nll_loss_1d(input_np, label_np, weight=weight_np)[0]
@@ -287,10 +313,9 @@ class TestNLLLoss(unittest.TestCase):
        with fluid.dygraph.guard():
            nll_loss = paddle.nn.loss.NLLLoss(
-                weight=fluid.dygraph.to_variable(weight_np), reduction='none')
+                weight=paddle.to_tensor(weight_np), reduction='none')
            dy_res = nll_loss(
-                fluid.dygraph.to_variable(input_np),
-                fluid.dygraph.to_variable(label_np))
+                paddle.to_tensor(input_np), paddle.to_tensor(label_np))
            dy_result = dy_res.numpy()

        expected = nll_loss_1d(
            input_np, label_np, weight=weight_np, reduction='none')
@@ -326,8 +351,7 @@ class TestNLLLoss(unittest.TestCase):
        with fluid.dygraph.guard():
            nll_loss = paddle.nn.loss.NLLLoss()
            dy_res = nll_loss(
-                fluid.dygraph.to_variable(input_np),
-                fluid.dygraph.to_variable(label_np))
+                paddle.to_tensor(input_np), paddle.to_tensor(label_np))
            dy_result = dy_res.numpy()

        expected = nll_loss_2d(input_np, label_np)[0]
@@ -363,8 +387,7 @@ class TestNLLLoss(unittest.TestCase):
        with fluid.dygraph.guard():
            nll_loss = paddle.nn.loss.NLLLoss(reduction='sum')
            dy_res = nll_loss(
-                fluid.dygraph.to_variable(input_np),
-                fluid.dygraph.to_variable(label_np))
+                paddle.to_tensor(input_np), paddle.to_tensor(label_np))
            dy_result = dy_res.numpy()

        expected = nll_loss_2d(input_np, label_np, reduction='sum')[0]
@@ -404,10 +427,9 @@ class TestNLLLoss(unittest.TestCase):
        with fluid.dygraph.guard():
            nll_loss = paddle.nn.loss.NLLLoss(
-                weight=fluid.dygraph.to_variable(weight_np))
+                weight=paddle.to_tensor(weight_np))
            dy_res = nll_loss(
-                fluid.dygraph.to_variable(input_np),
-                fluid.dygraph.to_variable(label_np))
+                paddle.to_tensor(input_np), paddle.to_tensor(label_np))
            dy_result = dy_res.numpy()

        expected = nll_loss_2d(input_np, label_np, weight=weight_np)[0]
@@ -445,10 +467,9 @@ class TestNLLLoss(unittest.TestCase):
        with fluid.dygraph.guard():
            nll_loss = paddle.nn.loss.NLLLoss(
-                weight=fluid.dygraph.to_variable(weight_np))
+                weight=paddle.to_tensor(weight_np))
            dy_res = nll_loss(
-                fluid.dygraph.to_variable(input_np),
-                fluid.dygraph.to_variable(label_np))
+                paddle.to_tensor(input_np), paddle.to_tensor(label_np))
            dy_result = dy_res.numpy()

        expected = nll_loss_2d(input_np, label_np, weight=weight_np)[0]
@@ -487,10 +508,9 @@ class TestNLLLoss(unittest.TestCase):
        with fluid.dygraph.guard():
            nll_loss = paddle.nn.loss.NLLLoss(
-                weight=fluid.dygraph.to_variable(weight_np), reduction='sum')
+                weight=paddle.to_tensor(weight_np), reduction='sum')
            dy_res = nll_loss(
-                fluid.dygraph.to_variable(input_np),
-                fluid.dygraph.to_variable(label_np))
+                paddle.to_tensor(input_np), paddle.to_tensor(label_np))
            dy_result = dy_res.numpy()

        expected = nll_loss_2d(
@@ -527,8 +547,7 @@ class TestNLLLoss(unittest.TestCase):
        with fluid.dygraph.guard():
            nll_loss = paddle.nn.loss.NLLLoss()
            dy_res = nll_loss(
-                fluid.dygraph.to_variable(input_np),
-                fluid.dygraph.to_variable(label_np))
+                paddle.to_tensor(input_np), paddle.to_tensor(label_np))
            dy_result = dy_res.numpy()

        input_shape = input_np.shape
@@ -572,10 +591,9 @@ class TestNLLLoss(unittest.TestCase):
        with fluid.dygraph.guard():
            nll_loss = paddle.nn.loss.NLLLoss(
-                weight=fluid.dygraph.to_variable(weight_np))
+                weight=paddle.to_tensor(weight_np))
            dy_res = nll_loss(
-                fluid.dygraph.to_variable(input_np),
-                fluid.dygraph.to_variable(label_np))
+                paddle.to_tensor(input_np), paddle.to_tensor(label_np))
            dy_result = dy_res.numpy()

        input_shape = input_np.shape
@@ -620,10 +638,9 @@ class TestNLLLoss(unittest.TestCase):
        with fluid.dygraph.guard():
            nll_loss = paddle.nn.loss.NLLLoss(
-                weight=fluid.dygraph.to_variable(weight_np), reduction='sum')
+                weight=paddle.to_tensor(weight_np), reduction='sum')
            dy_res = nll_loss(
-                fluid.dygraph.to_variable(input_np),
-                fluid.dygraph.to_variable(label_np))
+                paddle.to_tensor(input_np), paddle.to_tensor(label_np))
            dy_result = dy_res.numpy()

        input_shape = input_np.shape
@@ -671,10 +688,9 @@ class TestNLLLoss(unittest.TestCase):
        with fluid.dygraph.guard():
            nll_loss = paddle.nn.loss.NLLLoss(
-                weight=fluid.dygraph.to_variable(weight_np), reduction='none')
+                weight=paddle.to_tensor(weight_np), reduction='none')
            dy_res = nll_loss(
-                fluid.dygraph.to_variable(input_np),
-                fluid.dygraph.to_variable(label_np))
+                paddle.to_tensor(input_np), paddle.to_tensor(label_np))
            dy_result = dy_res.numpy()

        input_shape = input_np.shape
@@ -721,10 +737,9 @@ class TestNLLLoss(unittest.TestCase):
        with fluid.dygraph.guard():
            nll_loss = paddle.nn.loss.NLLLoss(
-                weight=fluid.dygraph.to_variable(weight_np), reduction='none')
+                weight=paddle.to_tensor(weight_np), reduction='none')
            dy_res = nll_loss(
-                fluid.dygraph.to_variable(input_np),
-                fluid.dygraph.to_variable(label_np))
+                paddle.to_tensor(input_np), paddle.to_tensor(label_np))
            dy_result = dy_res.numpy()

        input_shape = input_np.shape
@@ -749,6 +764,8 @@ class TestNLLLossOp1DWithReduce(OpTest):
        self.init_test_case()
        self.op_type = "nll_loss"
        self.with_weight = False
+        self.python_api = paddle.nn.functional.nll_loss
+        self.python_out_sig = ["Out"]
        np.random.seed(200)
        input_np = np.random.uniform(0.1, 0.8,
                                     self.input_shape).astype("float64")
@@ -769,7 +786,7 @@ class TestNLLLossOp1DWithReduce(OpTest):
        self.attrs = {'reduction': 'mean', 'ignore_index': -100}

    def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=False)

    def test_check_output_with_weight(self):
        self.with_weight = True
@@ -778,7 +795,7 @@ class TestNLLLossOp1DWithReduce(OpTest):
    def test_check_grad(self):
        self.with_weight = True
        place = fluid.CPUPlace()
-        self.check_grad_with_place(place, ['X'], 'Out')
+        self.check_grad_with_place(place, ['X'], 'Out', check_eager=False)
        if fluid.core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
            self.check_grad_with_place(place, ['X'], 'Out')
@@ -1014,4 +1031,5 @@ class TestNLLLossInvalidArgs(unittest.TestCase):

if __name__ == "__main__":
+    paddle.enable_static()
    unittest.main()
@@ -18,6 +18,7 @@ import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
+import paddle


class TestTopkOp(OpTest):
@@ -61,4 +62,5 @@ class TestTopkOp(OpTest):

if __name__ == "__main__":
+    paddle.enable_static()
    unittest.main()
@@ -19,6 +19,7 @@ import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid.core as core
+from paddle.fluid.framework import _test_eager_guard


def numpy_topk(x, k=1, axis=-1, largest=True):
@@ -45,6 +46,7 @@ class TestTopkOp(OpTest):
    def setUp(self):
        self.op_type = "top_k_v2"
+        self.python_api = paddle.topk
        self.dtype = np.float64
        self.input_data = np.random.rand(10, 20)
        self.init_args()
@@ -55,12 +57,10 @@ class TestTopkOp(OpTest):
        self.outputs = {'Out': output, 'Indices': indices}

    def test_check_output(self):
-        paddle.enable_static()
-        self.check_output()
+        self.check_output(check_eager=False)

    def test_check_grad(self):
-        paddle.enable_static()
-        self.check_grad(set(['X']), 'Out')
+        self.check_grad(set(['X']), 'Out', check_eager=False)


class TestTopkOp1(TestTopkOp):
@@ -85,6 +85,7 @@ class TestTopkOp3(OpTest):
    def setUp(self):
        self.op_type = "top_k_v2"
+        self.python_api = paddle.topk
        self.dtype = np.float64
        self.input_data = np.random.rand(16, 100)
        self.init_args()
@@ -103,6 +104,7 @@ class TestTopkOp4(TestTopkOp):
    def setUp(self):
        self.op_type = "top_k_v2"
+        self.python_api = paddle.topk
        self.dtype = np.float64
        self.input_data = np.random.rand(10, 10, 5)
        self.init_args()
@@ -121,6 +123,7 @@ class TestTopkOp5(TestTopkOp):
    def setUp(self):
        self.op_type = "top_k_v2"
+        self.python_api = paddle.topk
        self.dtype = np.float64
        self.input_data = np.random.rand(10, 10, 5)
        self.init_args()
@@ -139,6 +142,7 @@ class TestTopkOp6(OpTest):
    def setUp(self):
        self.op_type = "top_k_v2"
+        self.python_api = paddle.topk
        self.dtype = np.float64
        self.input_data = np.random.rand(80, 16384)
        self.init_args()
@@ -156,48 +160,64 @@ class TestTopKAPI(unittest.TestCase):
        self.large_input_data = np.random.rand(2, 1030)

    def run_dygraph(self, place):
-        paddle.disable_static(place)
-        input_tensor = paddle.to_tensor(self.input_data)
-        large_input_tensor = paddle.to_tensor(self.large_input_data)
-        # test case for basic test case 1
-        paddle_result = paddle.topk(input_tensor, k=2)
-        numpy_result = numpy_topk(self.input_data, k=2)
-        self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
-        self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
-        # test case for basic test case 2 with axis
-        paddle_result = paddle.topk(input_tensor, k=2, axis=1)
-        numpy_result = numpy_topk(self.input_data, k=2, axis=1)
-        self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
-        self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
-        # test case for basic test case 3 with tensor K
-        k_tensor = paddle.to_tensor(np.array([2]))
-        paddle_result = paddle.topk(input_tensor, k=k_tensor, axis=1)
-        numpy_result = numpy_topk(self.input_data, k=2, axis=1)
-        self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
-        self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
-        # test case for basic test case 4 with tensor largest
-        k_tensor = paddle.to_tensor(np.array([2]))
-        paddle_result = paddle.topk(input_tensor, k=2, axis=1, largest=False)
-        numpy_result = numpy_topk(self.input_data, k=2, axis=1, largest=False)
-        self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
-        self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
-        # test case for basic test case 5 with axis -1
-        k_tensor = paddle.to_tensor(np.array([2]))
-        paddle_result = paddle.topk(input_tensor, k=2, axis=-1, largest=False)
-        numpy_result = numpy_topk(self.input_data, k=2, axis=-1, largest=False)
-        self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
-        self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
-        # test case for basic test case 6 for the partial sort
-        paddle_result = paddle.topk(large_input_tensor, k=1, axis=-1)
-        numpy_result = numpy_topk(self.large_input_data, k=1, axis=-1)
-        self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
-        self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
-        # test case for basic test case 7 for the unsorted
-        paddle_result = paddle.topk(input_tensor, k=2, axis=1, sorted=False)
-        sort_paddle = numpy_topk(
-            np.array(paddle_result[0].numpy()), axis=1, k=2)
-        numpy_result = numpy_topk(self.input_data, k=2, axis=1)
-        self.assertTrue(np.allclose(sort_paddle[0], numpy_result[0]))
+        with paddle.fluid.dygraph.guard(place):
+            input_tensor = paddle.to_tensor(self.input_data)
+            large_input_tensor = paddle.to_tensor(self.large_input_data)
+            # test case for basic test case 1
+            paddle_result = paddle.topk(input_tensor, k=2)
+            numpy_result = numpy_topk(self.input_data, k=2)
+            self.assertTrue(
+                np.allclose(paddle_result[0].numpy(), numpy_result[0]))
+            self.assertTrue(
+                np.allclose(paddle_result[1].numpy(), numpy_result[1]))
+            # test case for basic test case 2 with axis
+            paddle_result = paddle.topk(input_tensor, k=2, axis=1)
+            numpy_result = numpy_topk(self.input_data, k=2, axis=1)
+            self.assertTrue(
+                np.allclose(paddle_result[0].numpy(), numpy_result[0]))
+            self.assertTrue(
+                np.allclose(paddle_result[1].numpy(), numpy_result[1]))
+            # test case for basic test case 3 with tensor K
+            k_tensor = paddle.to_tensor(np.array([2]))
+            paddle_result = paddle.topk(input_tensor, k=k_tensor, axis=1)
+            numpy_result = numpy_topk(self.input_data, k=2, axis=1)
+            self.assertTrue(
+                np.allclose(paddle_result[0].numpy(), numpy_result[0]))
+            self.assertTrue(
+                np.allclose(paddle_result[1].numpy(), numpy_result[1]))
+            # test case for basic test case 4 with tensor largest
+            k_tensor = paddle.to_tensor(np.array([2]))
+            paddle_result = paddle.topk(
+                input_tensor, k=2, axis=1, largest=False)
+            numpy_result = numpy_topk(
+                self.input_data, k=2, axis=1, largest=False)
+            self.assertTrue(
+                np.allclose(paddle_result[0].numpy(), numpy_result[0]))
+            self.assertTrue(
+                np.allclose(paddle_result[1].numpy(), numpy_result[1]))
+            # test case for basic test case 5 with axis -1
+            k_tensor = paddle.to_tensor(np.array([2]))
+            paddle_result = paddle.topk(
+                input_tensor, k=2, axis=-1, largest=False)
+            numpy_result = numpy_topk(
+                self.input_data, k=2, axis=-1, largest=False)
+            self.assertTrue(
+                np.allclose(paddle_result[0].numpy(), numpy_result[0]))
+            self.assertTrue(
+                np.allclose(paddle_result[1].numpy(), numpy_result[1]))
+            # test case for basic test case 6 for the partial sort
+            paddle_result = paddle.topk(large_input_tensor, k=1, axis=-1)
+            numpy_result = numpy_topk(self.large_input_data, k=1, axis=-1)
+            self.assertTrue(
+                np.allclose(paddle_result[0].numpy(), numpy_result[0]))
+            self.assertTrue(
+                np.allclose(paddle_result[1].numpy(), numpy_result[1]))
+            # test case for basic test case 7 for the unsorted
+            paddle_result = paddle.topk(input_tensor, k=2, axis=1, sorted=False)
+            sort_paddle = numpy_topk(
+                np.array(paddle_result[0].numpy()), axis=1, k=2)
+            numpy_result = numpy_topk(self.input_data, k=2, axis=1)
+            self.assertTrue(np.allclose(sort_paddle[0], numpy_result[0]))
    def run_static(self, place):
        paddle.enable_static()
@@ -264,14 +284,15 @@ class TestTopKAPI(unittest.TestCase):
            self.run_static(place)

    def test_errors(self):
-        paddle.disable_static()
-        x = paddle.to_tensor([1, 2, 3])
-        with self.assertRaises(BaseException):
-            paddle.topk(x, k=-1)
+        with paddle.fluid.dygraph.guard():
+            x = paddle.to_tensor([1, 2, 3])
+            with self.assertRaises(BaseException):
+                paddle.topk(x, k=-1)

-        with self.assertRaises(BaseException):
-            paddle.topk(x, k=0)
+            with self.assertRaises(BaseException):
+                paddle.topk(x, k=0)


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
@@ -74,6 +74,7 @@ class TestViterbiOp(OpTest):
    def setUp(self):
        self.op_type = "viterbi_decode"
+        self.python_api = paddle.text.viterbi_decode
        self.set_attr()
        bz, length, ntags = self.bz, self.len, self.ntags
        self.input = np.random.randn(bz, length, ntags).astype(self.dtype)
@@ -90,7 +91,7 @@ class TestViterbiOp(OpTest):
        self.outputs = {'Scores': scores, 'Path': path}

    def test_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)


class TestViterbiAPI(unittest.TestCase):
@@ -132,3 +133,8 @@ class TestViterbiAPI(unittest.TestCase):
    def test_static_net(self):
        for place in self.places:
            self.check_static_result(place)
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
@@ -31,7 +31,7 @@ def YoloBox(x, img_size, attrs):
    an_num = int((len(anchors) // 2))
    class_num = attrs['class_num']
    conf_thresh = attrs['conf_thresh']
-    downsample = attrs['downsample']
+    downsample = attrs['downsample_ratio']
    clip_bbox = attrs['clip_bbox']
    scale_x_y = attrs['scale_x_y']
    iou_aware = attrs['iou_aware']
@@ -92,13 +92,14 @@ class TestYoloBoxOp(OpTest):
    def setUp(self):
        self.initTestCase()
        self.op_type = 'yolo_box'
+        self.python_api = paddle.vision.ops.yolo_box
        x = np.random.random(self.x_shape).astype('float32')
        img_size = np.random.randint(10, 20, self.imgsize_shape).astype('int32')
        self.attrs = {
            'anchors': self.anchors,
            'class_num': self.class_num,
            'conf_thresh': self.conf_thresh,
-            'downsample': self.downsample,
+            'downsample_ratio': self.downsample,
            'clip_bbox': self.clip_bbox,
            'scale_x_y': self.scale_x_y,
            'iou_aware': self.iou_aware,
......
@@ -28,7 +28,7 @@ from ...tensor import clip
from ...tensor import sum
from ...tensor import sqrt
from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
-from ...fluid.framework import _varbase_creator
+from ...fluid.framework import _varbase_creator, _in_legacy_dygraph, in_dygraph_mode
from ...fluid import dygraph_utils
from ...fluid import layers
@@ -1616,7 +1616,7 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
    if epsilon > 1. or epsilon < 0.:
        raise ValueError("The value of epsilon must be between 0 and 1.")

-    if in_dynamic_mode():
+    if paddle.in_dynamic_mode():
        return _C_ops.label_smooth(label, prior_dist, 'epsilon', float(epsilon))

    check_variable_and_dtype(label, 'label', ['float32', 'float64'],
......
@@ -37,7 +37,7 @@ from paddle.utils import deprecated
from paddle import _C_ops
from paddle import in_dynamic_mode
from paddle.framework import core
-from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode
+from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_mode

__all__ = []
@@ -784,11 +784,12 @@ def nll_loss(input,
                         input_dims))
    n = input_shape[0]
    c = input_shape[1]
-    if in_dynamic_mode():
+    if _non_static_mode():
        if input_dims != 2 and input_dims != 4:
            input, _ = _C_ops.reshape2(input, None, 'shape', [n, c, 1, -1])
            label, _ = _C_ops.reshape2(label, None, 'shape', [n, 1, -1])
        out_shape = [n] + input_shape[2:]
        out, total_weight = _C_ops.nll_loss(input, label, weight,
                                            'ignore_index', ignore_index,
                                            'reduction', reduction)
......
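For reference, the mean-reduction value this path computes is -mean(input[i, label[i]]) over the batch; a small worked check (made-up values; input rows are log-probabilities):

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    probs = np.array([[0.7, 0.2, 0.1],
                      [0.1, 0.8, 0.1]], dtype="float32")
    logits = paddle.to_tensor(np.log(probs))    # shape [N, C], log-probabilities
    labels = paddle.to_tensor(np.array([0, 1], dtype="int64"))
    loss = F.nll_loss(logits, labels)           # reduction='mean' by default
    # -(log 0.7 + log 0.8) / 2 ~= 0.2899
    assert np.allclose(loss.numpy(), 0.2899, atol=1e-3)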
@@ -181,7 +181,7 @@ def batch_norm(x,
    trainable_statistics = not use_global_stats

    if in_dynamic_mode():
        # for dygraph need tuple
        attrs = ("momentum", momentum, "epsilon", epsilon, "is_test",
                 not training, "data_layout", data_format, "use_mkldnn", False,
                 "fuse_with_relu", False, "use_global_stats", use_global_stats,
......
@@ -1397,7 +1397,10 @@ def histogram(input, bins=100, min=0, max=0, name=None):
            result = paddle.histogram(inputs, bins=4, min=0, max=3)
            print(result) # [0, 2, 1, 0]
    """
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_histogram(input, bins, min, max)
+
+    if _in_legacy_dygraph():
        return _C_ops.histogram(input, "bins", bins, "min", min, "max", max)

    helper = LayerHelper('histogram', **locals())
......
@@ -536,6 +536,8 @@ def bitwise_and(x, y, out=None, name=None):
            res = paddle.bitwise_and(x, y)
            print(res) # [0, 2, 1]
    """
+    if in_dygraph_mode() and out == None:
+        return _C_ops.final_state_bitwise_and(x, y)
    return _bitwise_op(
        op_name="bitwise_and", x=x, y=y, name=name, out=out, binary_op=True)
@@ -562,6 +564,9 @@ def bitwise_or(x, y, out=None, name=None):
            res = paddle.bitwise_or(x, y)
            print(res) # [-1, -1, -3]
    """
+    if in_dygraph_mode() and out == None:
+        return _C_ops.final_state_bitwise_or(x, y)
+
    return _bitwise_op(
        op_name="bitwise_or", x=x, y=y, name=name, out=out, binary_op=True)
@@ -588,6 +593,8 @@ def bitwise_xor(x, y, out=None, name=None):
            res = paddle.bitwise_xor(x, y)
            print(res) # [-1, -3, -4]
    """
+    if in_dygraph_mode() and out == None:
+        return _C_ops.final_state_bitwise_xor(x, y)
    return _bitwise_op(
        op_name="bitwise_xor", x=x, y=y, name=name, out=out, binary_op=True)
@@ -612,6 +619,8 @@ def bitwise_not(x, out=None, name=None):
            res = paddle.bitwise_not(x)
            print(res) # [4, 0, -2]
    """
+    if in_dygraph_mode() and out == None:
+        return _C_ops.final_state_bitwise_not(x)
    return _bitwise_op(
        op_name="bitwise_not", x=x, y=None, name=name, out=out, binary_op=False)
......
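Note the out == None guard: the fast path is taken only when no preallocated output tensor is supplied, otherwise the call falls back to _bitwise_op. Behaviour is unchanged either way; a sketch with illustrative inputs (values chosen here, not taken from the docstrings above):

    import paddle

    x = paddle.to_tensor([5, 3])
    y = paddle.to_tensor([3, 6])
    print(paddle.bitwise_and(x, y).numpy())  # [1 2]
    print(paddle.bitwise_or(x, y).numpy())   # [7 7]
    print(paddle.bitwise_xor(x, y).numpy())  # [6 5]
    print(paddle.bitwise_not(x).numpy())     # [-6 -4]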
@@ -17,7 +17,7 @@ from collections import Counter
from ..static import Variable, device_guard
from ..framework import core
-from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _in_eager_without_dygraph_check
+from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _in_eager_without_dygraph_check, _non_static_mode
from ..fluid.layer_helper import LayerHelper
from ..framework import OpProtoHolder, convert_np_dtype_to_dtype_, dygraph_only
from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
@@ -1845,7 +1845,7 @@ def expand_as(x, y, name=None):
            np_out = out.numpy()
            # [[1, 2, 3], [1, 2, 3]]
    """
-    if paddle.in_dynamic_mode():
+    if _non_static_mode():
        return _C_ops.expand_as_v2(x, 'target_shape', y.shape)

    check_variable_and_dtype(
......
@@ -2681,7 +2681,9 @@ def isfinite(x, name=None):
            out = paddle.tensor.isfinite(x)
            print(out) # [False True True False True False False]
    """
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_isfinite(x)
+    if _in_legacy_dygraph():
        return _C_ops.isfinite_v2(x)
    helper = LayerHelper("isfinite_v2", **locals())
    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'isfinite')
@@ -2709,7 +2711,9 @@ def isinf(x, name=None):
            out = paddle.tensor.isinf(x)
            print(out) # [ True False False True False False False]
    """
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_isinf(x)
+    if _in_legacy_dygraph():
        return _C_ops.isinf_v2(x)
    helper = LayerHelper("isinf_v2", **locals())
    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'isinf')
@@ -2737,7 +2741,10 @@ def isnan(x, name=None):
            out = paddle.tensor.isnan(x)
            print(out) # [False False False False False True True]
    """
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_isnan(x)
+
+    if _in_legacy_dygraph():
        return _C_ops.isnan_v2(x)
    helper = LayerHelper("isnan_v2", **locals())
    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'isnan')
@@ -3387,8 +3394,13 @@ def lerp(x, y, weight, name=None):
            # out: [5.5, 6., 6.5, 7.]
    """
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        check_type(weight, 'weight', (float, paddle.Tensor, Variable), 'lerp')
+        if isinstance(weight, float):
+            weight = paddle.to_tensor(weight, dtype=x.dtype)
+        return _C_ops.final_state_lerp(x, y, weight)
+    if _in_legacy_dygraph():
        if isinstance(weight, float):
            weight = paddle.to_tensor(weight, dtype=x.dtype)
        return _C_ops.lerp(x, y, weight)
......
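In the eager branch a Python float weight is first wrapped in a tensor of x's dtype before final_state_lerp is called, since the generated op takes a Tensor weight. A usage sketch consistent with the docstring output above (x and y values assumed):

    import numpy as np
    import paddle

    x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0])
    y = paddle.full([4], 10.0)
    out = paddle.lerp(x, y, 0.5)  # elementwise x + 0.5 * (y - x)
    assert np.allclose(out.numpy(), [5.5, 6.0, 6.5, 7.0])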
@@ -18,7 +18,7 @@ from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype
from ..fluid import layers
from ..framework import core
-from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode
+from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_mode
from paddle.common_ops_import import convert_np_dtype_to_dtype_
from paddle.common_ops_import import Variable
from paddle.common_ops_import import VarDesc
@@ -774,7 +774,10 @@ def masked_select(x, mask, name=None):
            #[1.0 5.0 6.0 9.0]
    """

-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_masked_select(x, mask)
+
+    if _in_legacy_dygraph():
        return _C_ops.masked_select(x, mask)

    helper = LayerHelper("masked_select", **locals())
@@ -844,8 +847,8 @@ def topk(x, k, axis=None, largest=True, sorted=True, name=None):
            # [[1 1 0 0]]
    """

-    if paddle.in_dynamic_mode():
-        k = k.numpy().item(0) if isinstance(k, Variable) else k
+    if _non_static_mode():
        if axis is None:
            out, indices = _C_ops.top_k_v2(x, 'k',
                                           int(k), 'largest', largest, 'sorted',
......
@@ -13,7 +13,7 @@
# limitations under the License.

from ..nn import Layer
-from ..fluid.framework import core, _non_static_mode
+from ..fluid.framework import core, _non_static_mode, in_dygraph_mode
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype, check_type
from paddle import _C_ops
@@ -58,6 +58,10 @@ def viterbi_decode(potentials,
            transition = paddle.rand((num_tags, num_tags), dtype='float32')
            scores, path = paddle.text.viterbi_decode(emission, transition, length, False)
            # scores: [3.37089300, 1.56825531], path: [[1, 0, 0], [1, 1, 0]]
    """
+    if in_dygraph_mode():
+        return _C_ops.final_state_viterbi_decode(potentials, transition_params,
+                                                 lengths, include_bos_eos_tag)
    if _non_static_mode():
        return _C_ops.viterbi_decode(potentials, transition_params, lengths,
                                     'include_bos_eos_tag', include_bos_eos_tag)
......
@@ -547,6 +547,15 @@
    func : hard_sigmoid
  backward : hard_sigmoid_grad

+# histogram
+- api : histogram
+  args : (Tensor x, int64_t bins, int min, int max)
+  output : Tensor
+  infer_meta :
+    func : HistogramInferMeta
+  kernel :
+    func : histogram
+
- api : huber_loss
  args : (Tensor input, Tensor label, float delta)
  output : Tensor(out), Tensor(residual)
......
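Each api entry drives code generation: args fixes the argument order for both the C++ and Python bindings, infer_meta names the shape/dtype inference function, and kernel binds the phi kernel; the final_state_histogram entry point used earlier in this commit is generated from this block. A small usage sketch (input values chosen here; counts fall into five unit-width bins, with values equal to max counted in the last bin):

    import paddle

    inputs = paddle.to_tensor([1, 2, 1, 4, 5])
    result = paddle.histogram(inputs, bins=5, min=0, max=5)
    print(result.numpy())  # [0 2 1 0 2]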
@@ -19,7 +19,7 @@ from ..fluid import core, layers
from ..fluid.layers import nn, utils
from ..nn import Layer, Conv2D, Sequential, ReLU, BatchNorm2D
from ..fluid.initializer import Normal
-from ..fluid.framework import _non_static_mode
+from ..fluid.framework import _non_static_mode, in_dygraph_mode
from paddle.common_ops_import import *
from paddle import _C_ops
@@ -377,6 +377,12 @@ def yolo_box(x,
                           clip_bbox=True,
                           scale_x_y=1.)
    """
+    if in_dygraph_mode():
+        boxes, scores = _C_ops.final_state_yolo_box(
+            x, img_size, anchors, class_num, conf_thresh, downsample_ratio,
+            clip_bbox, scale_x_y, iou_aware, iou_aware_factor)
+        return boxes, scores
+
    if _non_static_mode():
        boxes, scores = _C_ops.yolo_box(
            x, img_size, 'anchors', anchors, 'class_num', class_num,
......