Unverified commit 98303291, authored by hong, committed by GitHub

Add basic yaml backward (#40751)

* fix error; test=develop

* update

* close some yaml

* fix backward attribute error; test=develop

* add div test

* polish code; test=develop

* update

* update

* fix bug

* update bitwise code; test=develop

* update

* update

* fix some bug

* update

* revert cmakelist

* fix optional bug;

* fix bug

* fix bug;

* add backward test

* open bn

* update

* update

* revert eager_gen

* polish code

* fix topk error

* update

* update

* fix bug;

* move label smooth, nll loss

* revert topk

* fix topk label smooth bug;

* remove batch_norm

* remove topk

* change flip infer meta

* fix flip bug

* update yaml

* close abs

* fix histogram bug

* fix histogram bug

* add abs

* fix histogram kernel

* remove expand
Parent 2f41f389
@@ -68,9 +68,9 @@ void IndexSampleGradInner(const Context& context,
 template <typename T, typename Context>
 void IndexSampleGradKernel(const Context& ctx,
-                           const DenseTensor& out_grad,
                            const DenseTensor& x,
                            const DenseTensor& index,
+                           const DenseTensor& out_grad,
                            DenseTensor* x_grad) {
   auto index_type = index.dtype();
   bool index_type_match =
......
@@ -21,9 +21,9 @@ namespace phi {
 template <typename T, typename Context>
 void MaskedSelectGradKernel(const Context& dev_ctx,
-                            const DenseTensor& out_grad,
                             const DenseTensor& x,
                             const DenseTensor& mask,
+                            const DenseTensor& out_grad,
                             DenseTensor* x_grad) {
   auto* mask_data = mask.data<bool>();
   auto* input_data = out_grad.data<T>();
......
@@ -121,8 +121,8 @@ template <typename T, typename Context>
 void NllLossGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& labels,
-                       const DenseTensor& total_weight,
                        paddle::optional<const DenseTensor&> weight,
+                       const DenseTensor& total_weight,
                        const DenseTensor& d_out,
                        int64_t ignore_index,
                        const std::string& reduction,
......
@@ -51,17 +51,17 @@ static void FullTopKAssign(const Type& input_height,
 template <typename T, typename Context>
 void TopkGradKernel(const Context& dev_ctx,
-                    const DenseTensor& out_grad,
                     const DenseTensor& x,
                     const DenseTensor& indices,
-                    int k,
+                    const DenseTensor& out_grad,
+                    const Scalar& k_scalar,
                     int axis,
                     bool largest,
                     bool sorted,
                     DenseTensor* x_grad) {
   const auto& in_dims = x.dims();
   const auto& out_dims = indices.dims();
+  int k = k_scalar.to<int>();
   // axis < 0, get the real axis
   axis = (axis < 0) ? (in_dims.size() + axis) : axis;
......
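A minimal sketch (not part of this diff) of how the Scalar-typed `k` surfaces in the Python API: `paddle.topk` accepts `k` either as a plain int or as a Tensor, which is what `k_scalar.to<int>()` unpacks inside the grad kernel.

```python
import numpy as np
import paddle

x = paddle.to_tensor(np.random.rand(3, 5).astype("float32"))

# k given as a plain Python int
values, indices = paddle.topk(x, k=2, axis=1)

# k given as a 1-element Tensor, as the top_k_v2 tests below also exercise
k_t = paddle.to_tensor(np.array([2]))
values_t, indices_t = paddle.topk(x, k=k_t, axis=1)

assert values.shape == [3, 2] and values_t.shape == [3, 2]
```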
@@ -36,7 +36,7 @@ void LimitGridDim(const Context& ctx, dim3* grid_dim) {
 #define PREDEFINED_BLOCK_SIZE_X 512
 #define PREDEFINED_BLOCK_SIZE 1024
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
-};
+} // namespace
 template <typename T, typename IndexT = int>
 __global__ void IndexSampleGrad(const IndexT* index,
@@ -67,9 +67,9 @@ __global__ void IndexSampleGrad(const IndexT* index,
 template <typename T, typename Context>
 void IndexSampleGradKernel(const Context& ctx,
-                           const DenseTensor& out_grad,
                            const DenseTensor& x,
                            const DenseTensor& index,
+                           const DenseTensor& out_grad,
                            DenseTensor* x_grad) {
   const T* output_grad_data = out_grad.data<T>();
   T* input_grad_data = ctx.template Alloc<T>(x_grad);
......
@@ -44,9 +44,9 @@ struct MaskedSelectGradFunctor {
 template <typename T, typename Context>
 void MaskedSelectGradKernel(const Context& dev_ctx,
-                            const DenseTensor& out_grad,
                             const DenseTensor& x,
                             const DenseTensor& mask,
+                            const DenseTensor& out_grad,
                             DenseTensor* x_grad) {
   auto mask_size = mask.numel();
   dev_ctx.template Alloc<T>(x_grad);
......
@@ -23,8 +23,8 @@ template <typename T, typename Context>
 void NllLossGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& labels,
-                       const DenseTensor& total_weight,
                        paddle::optional<const DenseTensor&> weight,
+                       const DenseTensor& total_weight,
                        const DenseTensor& dout,
                        int64_t ignore_index,
                        const std::string& reduction,
......
@@ -25,10 +25,10 @@ namespace ops = paddle::operators;
 template <typename T, typename Context>
 void TopkGradKernel(const Context& dev_ctx,
-                    const DenseTensor& out_grad,
                     const DenseTensor& x,
                     const DenseTensor& indices,
-                    int k,
+                    const DenseTensor& out_grad,
+                    const Scalar& k_scalar,
                     int axis,
                     bool largest,
                     bool sorted,
@@ -36,6 +36,8 @@ void TopkGradKernel(const Context& dev_ctx,
   const auto& in_dims = x.dims();
   const auto& out_dims = indices.dims();
+  int k = k_scalar.to<int>();
+
   // get the real the axis and the k
   if (axis < 0) {
     axis += in_dims.size();
......
@@ -18,11 +18,11 @@
 namespace phi {
 template <typename T, typename Context>
-void HistogramSelectKernel(const Context& dev_ctx,
+void HistogramKernel(const Context& dev_ctx,
                      const DenseTensor& input,
                      int64_t bins,
                      int min,
                      int max,
-                     DenseTensor* out);
+                     DenseTensor* output);
 } // namespace phi
@@ -20,9 +20,9 @@ namespace phi {
 template <typename T, typename Context>
 void IndexSampleGradKernel(const Context& ctx,
-                           const DenseTensor& out_grad,
                            const DenseTensor& x,
                            const DenseTensor& index,
+                           const DenseTensor& out_grad,
                            DenseTensor* in_grad);
 } // namespace phi
@@ -19,9 +19,9 @@ namespace phi {
 template <typename T, typename Context>
 void MaskedSelectGradKernel(const Context& dev_ctx,
-                            const DenseTensor& out_grad,
                             const DenseTensor& x,
                             const DenseTensor& mask,
+                            const DenseTensor& out_grad,
                             DenseTensor* x_grad);
 } // namspace phi
@@ -22,8 +22,8 @@ template <typename T, typename Context>
 void NllLossGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& label,
-                       const DenseTensor& total_weight,
                        paddle::optional<const DenseTensor&> weight,
+                       const DenseTensor& total_weight,
                        const DenseTensor& d_out,
                        int64_t ignore_index,
                        const std::string& reduction,
......
@@ -14,16 +14,17 @@
 #pragma once
+#include "paddle/phi/common/scalar.h"
 #include "paddle/phi/core/dense_tensor.h"
 namespace phi {
 template <typename T, typename Context>
 void TopkGradKernel(const Context& dev_ctx,
-                    const DenseTensor& out_grad,
                     const DenseTensor& x,
                     const DenseTensor& indices,
-                    int k,
+                    const DenseTensor& out_grad,
+                    const Scalar& k,
                     int axis,
                     bool largest,
                     bool sorted,
......
@@ -19,7 +19,7 @@ namespace phi {
 KernelSignature IndexSampleGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature("index_sample_grad",
-                         {GradVarName("Out"), "X", "Index"},
+                         {"X", "Index", GradVarName("Out")},
                          {},
                          {GradVarName("X")});
 }
......
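For reference, a small usage sketch of the op behind this mapping, using only the public `paddle.index_sample` API; the values and shapes here are illustrative:

```python
import numpy as np
import paddle

x = paddle.to_tensor(
    np.arange(12).reshape(3, 4).astype("float32"), stop_gradient=False)
index = paddle.to_tensor(np.array([[0, 2], [1, 3], [0, 0]]).astype("int64"))

out = paddle.index_sample(x, index)  # per-row gather, shape [3, 2]
out.sum().backward()                 # drives index_sample_grad

print(out.numpy())
print(x.grad)  # 1 accumulated per sampled (row, col); position (2, 0) gets 2 here
```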
@@ -24,7 +24,7 @@ KernelSignature MaskedSelectOpArgumentMapping(
 KernelSignature MaskedSelectGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature("masked_select_grad",
-                         {GradVarName("Y"), "X", "Mask"},
+                         {"X", "Mask", GradVarName("Y")},
                          {},
                          {GradVarName("X")});
 }
......
@@ -29,7 +29,7 @@ KernelSignature NllLossGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature(
       "nll_loss_grad",
-      {"X", "Label", "Total_weight", "Weight", GradVarName("Out")},
+      {"X", "Label", "Weight", "Total_weight", GradVarName("Out")},
       {"ignore_index", "reduction"},
       {GradVarName("X")});
 }
......
@@ -29,7 +29,7 @@ KernelSignature TopkOpArgumentMapping(const ArgumentMappingContext& ctx) {
 KernelSignature TopkGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
   return KernelSignature("top_k_grad",
-                         {GradVarName("Out"), "X", "Indices"},
+                         {"X", "Indices", GradVarName("Out")},
                          {"k", "axis", "largest", "sorted"},
                          {GradVarName("X")});
 }
......
@@ -12529,6 +12529,9 @@ def logical_and(x, y, out=None, name=None):
             res = paddle.logical_and(x, y)
             print(res) # [True False True False]
     """
+    if in_dygraph_mode():
+        return _C_ops.final_state_logical_and(x, y)
+
     return _logical_op(
         op_name="logical_and", x=x, y=y, name=name, out=out, binary_op=True)
@@ -12568,6 +12571,8 @@ def logical_or(x, y, out=None, name=None):
             res = paddle.logical_or(x, y)
             print(res) # [[ True True] [ True False]]
     """
+    if in_dygraph_mode():
+        return _C_ops.final_state_logical_or(x, y)
     return _logical_op(
         op_name="logical_or", x=x, y=y, name=name, out=out, binary_op=True)
@@ -12607,6 +12612,9 @@ def logical_xor(x, y, out=None, name=None):
             res = paddle.logical_xor(x, y)
             print(res) # [[False, True], [ True, False]]
     """
+    if in_dygraph_mode():
+        return _C_ops.final_state_logical_xor(x, y)
+
     return _logical_op(
         op_name="logical_xor", x=x, y=y, name=name, out=out, binary_op=True)
@@ -12639,7 +12647,8 @@ def logical_not(x, out=None, name=None):
             res = paddle.logical_not(x)
             print(res) # [False True False True]
     """
+    if in_dygraph_mode():
+        return _C_ops.final_state_logical_not(x)
     return _logical_op(
         op_name="logical_not", x=x, y=None, name=name, out=out, binary_op=False)
......
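A quick sketch of the public behaviour these dygraph branches dispatch to (plain `paddle.logical_*` calls; the `final_state_` routing itself stays internal):

```python
import paddle

x = paddle.to_tensor([True, False, True, False])
y = paddle.to_tensor([True, True, False, False])

print(paddle.logical_and(x, y))  # [True, False, False, False]
print(paddle.logical_or(x, y))   # [True, True, True, False]
print(paddle.logical_xor(x, y))  # [False, True, True, False]
print(paddle.logical_not(x))     # [False, True, False, True]
```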
...@@ -19,7 +19,7 @@ import paddle.fluid.core as core ...@@ -19,7 +19,7 @@ import paddle.fluid.core as core
from paddle.fluid.op import Operator from paddle.fluid.op import Operator
import paddle.fluid as fluid import paddle.fluid as fluid
from op_test import OpTest, _set_use_system_allocator from op_test import OpTest, _set_use_system_allocator
from paddle.fluid.framework import grad_var_name from paddle.fluid.framework import grad_var_name, _test_eager_guard
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import Program, program_guard from paddle.fluid import Program, program_guard
import paddle import paddle
...@@ -46,32 +46,32 @@ class TestBatchNorm(unittest.TestCase): ...@@ -46,32 +46,32 @@ class TestBatchNorm(unittest.TestCase):
def error1d_dataformat(): def error1d_dataformat():
x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32') x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32')
batch_norm1d = paddle.nn.BatchNorm1D(1, data_format='NCDHW') batch_norm1d = paddle.nn.BatchNorm1D(1, data_format='NCDHW')
batch_norm1d(fluid.dygraph.to_variable(x_data_4)) batch_norm1d(paddle.to_tensor(x_data_4))
def error2d_dataformat(): def error2d_dataformat():
x_data_3 = np.random.random(size=(2, 1, 3)).astype('float32') x_data_3 = np.random.random(size=(2, 1, 3)).astype('float32')
batch_norm2d = paddle.nn.BatchNorm2D(1, data_format='NCDHW') batch_norm2d = paddle.nn.BatchNorm2D(1, data_format='NCDHW')
batch_norm2d(fluid.dygraph.to_variable(x_data_3)) batch_norm2d(paddle.to_tensor(x_data_3))
def error3d_dataformat(): def error3d_dataformat():
x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32') x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32')
batch_norm3d = paddle.nn.BatchNorm3D(1, data_format='NCL') batch_norm3d = paddle.nn.BatchNorm3D(1, data_format='NCL')
batch_norm3d(fluid.dygraph.to_variable(x_data_4)) batch_norm3d(paddle.to_tensor(x_data_4))
def error1d(): def error1d():
x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32') x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32')
batch_norm1d = paddle.nn.BatchNorm1D(1) batch_norm1d = paddle.nn.BatchNorm1D(1)
batch_norm1d(fluid.dygraph.to_variable(x_data_4)) batch_norm1d(paddle.to_tensor(x_data_4))
def error2d(): def error2d():
x_data_3 = np.random.random(size=(2, 1, 3)).astype('float32') x_data_3 = np.random.random(size=(2, 1, 3)).astype('float32')
batch_norm2d = paddle.nn.BatchNorm2D(1) batch_norm2d = paddle.nn.BatchNorm2D(1)
batch_norm2d(fluid.dygraph.to_variable(x_data_3)) batch_norm2d(paddle.to_tensor(x_data_3))
def error3d(): def error3d():
x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32') x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32')
batch_norm3d = paddle.nn.BatchNorm3D(1) batch_norm3d = paddle.nn.BatchNorm3D(1)
batch_norm3d(fluid.dygraph.to_variable(x_data_4)) batch_norm3d(paddle.to_tensor(x_data_4))
with fluid.dygraph.guard(p): with fluid.dygraph.guard(p):
self.assertRaises(ValueError, error1d) self.assertRaises(ValueError, error1d)
...@@ -94,13 +94,18 @@ class TestBatchNorm(unittest.TestCase): ...@@ -94,13 +94,18 @@ class TestBatchNorm(unittest.TestCase):
shape[1], shape[1],
is_test=is_test, is_test=is_test,
trainable_statistics=trainable_statistics) trainable_statistics=trainable_statistics)
y = bn(fluid.dygraph.to_variable(x)) y = bn(paddle.to_tensor(x))
return y.numpy() return y.numpy()
def compute_v2(x): def compute_v2(x):
with fluid.dygraph.guard(p): with fluid.dygraph.guard(p):
bn = paddle.nn.BatchNorm2D(shape[1]) bn = paddle.nn.BatchNorm2D(shape[1])
y = bn(fluid.dygraph.to_variable(x)) y = bn(paddle.to_tensor(x))
with _test_eager_guard():
bn = paddle.nn.BatchNorm2D(shape[1])
eag_y = bn(paddle.to_tensor(x))
assert np.allclose(eag_y.numpy(), y.numpy())
return y.numpy() return y.numpy()
def compute_v3(x, is_test, trainable_statistics): def compute_v3(x, is_test, trainable_statistics):
...@@ -115,14 +120,14 @@ class TestBatchNorm(unittest.TestCase): ...@@ -115,14 +120,14 @@ class TestBatchNorm(unittest.TestCase):
initializer=fluid.initializer.Constant(0.0), initializer=fluid.initializer.Constant(0.0),
trainable=False), trainable=False),
trainable_statistics=trainable_statistics) trainable_statistics=trainable_statistics)
y = bn(fluid.dygraph.to_variable(x)) y = bn(paddle.to_tensor(x))
return y.numpy() return y.numpy()
def compute_v4(x): def compute_v4(x):
with fluid.dygraph.guard(p): with fluid.dygraph.guard(p):
bn = paddle.nn.BatchNorm2D( bn = paddle.nn.BatchNorm2D(
shape[1], weight_attr=False, bias_attr=False) shape[1], weight_attr=False, bias_attr=False)
y = bn(fluid.dygraph.to_variable(x)) y = bn(paddle.to_tensor(x))
return y.numpy() return y.numpy()
x = np.random.randn(*shape).astype("float32") x = np.random.randn(*shape).astype("float32")
......
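The `compute_v2` change compares eager-mode output against default dygraph output for `BatchNorm2D`. A standalone sketch of the same comparison, assuming (as the test does) that the default constant initializers make both runs deterministic:

```python
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import _test_eager_guard

shape = (4, 10, 16, 16)
x = np.random.randn(*shape).astype("float32")

with fluid.dygraph.guard():
    bn = paddle.nn.BatchNorm2D(shape[1])
    y = bn(paddle.to_tensor(x))

    with _test_eager_guard():
        bn_eager = paddle.nn.BatchNorm2D(shape[1])
        y_eager = bn_eager(paddle.to_tensor(x))

    assert np.allclose(y_eager.numpy(), y.numpy())
```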
...@@ -32,6 +32,7 @@ class ElementwiseDivOp(OpTest): ...@@ -32,6 +32,7 @@ class ElementwiseDivOp(OpTest):
'X': np.random.random((32,84)).astype("float32"), 'X': np.random.random((32,84)).astype("float32"),
'Y': np.random.random((32,84)).astype("float32") 'Y': np.random.random((32,84)).astype("float32")
""" """
self.inputs = { self.inputs = {
'X': np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype), 'X': np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype),
'Y': np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) 'Y': np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
...@@ -39,7 +40,7 @@ class ElementwiseDivOp(OpTest): ...@@ -39,7 +40,7 @@ class ElementwiseDivOp(OpTest):
self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
def check_eager(self): def check_eager(self):
return (self.use_mkldnn == False and self.axis == -1) return (not hasattr(self, "attrs") or (self.attrs["axis"] != -1))
def test_check_output(self): def test_check_output(self):
self.check_output(check_eager=False) self.check_output(check_eager=False)
...@@ -65,6 +66,7 @@ class ElementwiseDivOp(OpTest): ...@@ -65,6 +66,7 @@ class ElementwiseDivOp(OpTest):
class TestElementwiseDivOpBF16(OpTest): class TestElementwiseDivOpBF16(OpTest):
def setUp(self): def setUp(self):
self.op_type = "elementwise_div" self.op_type = "elementwise_div"
self.python_api = paddle.divide
self.dtype = np.uint16 self.dtype = np.uint16
x = np.random.uniform(0.1, 1, [12, 13]).astype(np.float32) x = np.random.uniform(0.1, 1, [12, 13]).astype(np.float32)
...@@ -100,6 +102,7 @@ class TestElementwiseDivOpBF16(OpTest): ...@@ -100,6 +102,7 @@ class TestElementwiseDivOpBF16(OpTest):
class TestElementwiseDivOp_scalar(ElementwiseDivOp): class TestElementwiseDivOp_scalar(ElementwiseDivOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_div" self.op_type = "elementwise_div"
self.python_api = paddle.divide
self.inputs = { self.inputs = {
'X': np.random.uniform(0.1, 1, [20, 3, 4]).astype(np.float64), 'X': np.random.uniform(0.1, 1, [20, 3, 4]).astype(np.float64),
'Y': np.random.uniform(0.1, 1, [1]).astype(np.float64) 'Y': np.random.uniform(0.1, 1, [1]).astype(np.float64)
...@@ -110,6 +113,7 @@ class TestElementwiseDivOp_scalar(ElementwiseDivOp): ...@@ -110,6 +113,7 @@ class TestElementwiseDivOp_scalar(ElementwiseDivOp):
class TestElementwiseDivOp_Vector(ElementwiseDivOp): class TestElementwiseDivOp_Vector(ElementwiseDivOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_div" self.op_type = "elementwise_div"
self.python_api = paddle.divide
self.inputs = { self.inputs = {
'X': np.random.uniform(0.1, 1, [100]).astype("float64"), 'X': np.random.uniform(0.1, 1, [100]).astype("float64"),
'Y': np.random.uniform(0.1, 1, [100]).astype("float64") 'Y': np.random.uniform(0.1, 1, [100]).astype("float64")
...@@ -120,6 +124,7 @@ class TestElementwiseDivOp_Vector(ElementwiseDivOp): ...@@ -120,6 +124,7 @@ class TestElementwiseDivOp_Vector(ElementwiseDivOp):
class TestElementwiseDivOp_broadcast_0(ElementwiseDivOp): class TestElementwiseDivOp_broadcast_0(ElementwiseDivOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_div" self.op_type = "elementwise_div"
self.python_api = paddle.divide
self.inputs = { self.inputs = {
'X': np.random.uniform(0.1, 1, [100, 3, 4]).astype("float64"), 'X': np.random.uniform(0.1, 1, [100, 3, 4]).astype("float64"),
'Y': np.random.uniform(0.1, 1, [100]).astype("float64") 'Y': np.random.uniform(0.1, 1, [100]).astype("float64")
...@@ -135,6 +140,7 @@ class TestElementwiseDivOp_broadcast_0(ElementwiseDivOp): ...@@ -135,6 +140,7 @@ class TestElementwiseDivOp_broadcast_0(ElementwiseDivOp):
class TestElementwiseDivOp_broadcast_1(ElementwiseDivOp): class TestElementwiseDivOp_broadcast_1(ElementwiseDivOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_div" self.op_type = "elementwise_div"
self.python_api = paddle.divide
self.inputs = { self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 100, 4]).astype("float64"), 'X': np.random.uniform(0.1, 1, [2, 100, 4]).astype("float64"),
'Y': np.random.uniform(0.1, 1, [100]).astype("float64") 'Y': np.random.uniform(0.1, 1, [100]).astype("float64")
...@@ -150,6 +156,7 @@ class TestElementwiseDivOp_broadcast_1(ElementwiseDivOp): ...@@ -150,6 +156,7 @@ class TestElementwiseDivOp_broadcast_1(ElementwiseDivOp):
class TestElementwiseDivOp_broadcast_2(ElementwiseDivOp): class TestElementwiseDivOp_broadcast_2(ElementwiseDivOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_div" self.op_type = "elementwise_div"
self.python_api = paddle.divide
self.inputs = { self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 3, 100]).astype("float64"), 'X': np.random.uniform(0.1, 1, [2, 3, 100]).astype("float64"),
'Y': np.random.uniform(0.1, 1, [100]).astype("float64") 'Y': np.random.uniform(0.1, 1, [100]).astype("float64")
...@@ -164,6 +171,7 @@ class TestElementwiseDivOp_broadcast_2(ElementwiseDivOp): ...@@ -164,6 +171,7 @@ class TestElementwiseDivOp_broadcast_2(ElementwiseDivOp):
class TestElementwiseDivOp_broadcast_3(ElementwiseDivOp): class TestElementwiseDivOp_broadcast_3(ElementwiseDivOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_div" self.op_type = "elementwise_div"
self.python_api = paddle.divide
self.inputs = { self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 10, 12, 5]).astype("float64"), 'X': np.random.uniform(0.1, 1, [2, 10, 12, 5]).astype("float64"),
'Y': np.random.uniform(0.1, 1, [10, 12]).astype("float64") 'Y': np.random.uniform(0.1, 1, [10, 12]).astype("float64")
...@@ -179,6 +187,7 @@ class TestElementwiseDivOp_broadcast_3(ElementwiseDivOp): ...@@ -179,6 +187,7 @@ class TestElementwiseDivOp_broadcast_3(ElementwiseDivOp):
class TestElementwiseDivOp_broadcast_4(ElementwiseDivOp): class TestElementwiseDivOp_broadcast_4(ElementwiseDivOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_div" self.op_type = "elementwise_div"
self.python_api = paddle.divide
self.inputs = { self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 3, 50]).astype("float64"), 'X': np.random.uniform(0.1, 1, [2, 3, 50]).astype("float64"),
'Y': np.random.uniform(0.1, 1, [2, 1, 50]).astype("float64") 'Y': np.random.uniform(0.1, 1, [2, 1, 50]).astype("float64")
...@@ -189,6 +198,7 @@ class TestElementwiseDivOp_broadcast_4(ElementwiseDivOp): ...@@ -189,6 +198,7 @@ class TestElementwiseDivOp_broadcast_4(ElementwiseDivOp):
class TestElementwiseDivOp_broadcast_5(ElementwiseDivOp): class TestElementwiseDivOp_broadcast_5(ElementwiseDivOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_div" self.op_type = "elementwise_div"
self.python_api = paddle.divide
self.inputs = { self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 3, 4, 20]).astype("float64"), 'X': np.random.uniform(0.1, 1, [2, 3, 4, 20]).astype("float64"),
'Y': np.random.uniform(0.1, 1, [2, 3, 1, 20]).astype("float64") 'Y': np.random.uniform(0.1, 1, [2, 3, 1, 20]).astype("float64")
...@@ -199,6 +209,7 @@ class TestElementwiseDivOp_broadcast_5(ElementwiseDivOp): ...@@ -199,6 +209,7 @@ class TestElementwiseDivOp_broadcast_5(ElementwiseDivOp):
class TestElementwiseDivOp_commonuse_1(ElementwiseDivOp): class TestElementwiseDivOp_commonuse_1(ElementwiseDivOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_div" self.op_type = "elementwise_div"
self.python_api = paddle.divide
self.inputs = { self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 3, 100]).astype("float64"), 'X': np.random.uniform(0.1, 1, [2, 3, 100]).astype("float64"),
'Y': np.random.uniform(0.1, 1, [1, 1, 100]).astype("float64"), 'Y': np.random.uniform(0.1, 1, [1, 1, 100]).astype("float64"),
...@@ -209,6 +220,7 @@ class TestElementwiseDivOp_commonuse_1(ElementwiseDivOp): ...@@ -209,6 +220,7 @@ class TestElementwiseDivOp_commonuse_1(ElementwiseDivOp):
class TestElementwiseDivOp_commonuse_2(ElementwiseDivOp): class TestElementwiseDivOp_commonuse_2(ElementwiseDivOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_div" self.op_type = "elementwise_div"
self.python_api = paddle.divide
self.inputs = { self.inputs = {
'X': np.random.uniform(0.1, 1, [30, 3, 1, 5]).astype("float64"), 'X': np.random.uniform(0.1, 1, [30, 3, 1, 5]).astype("float64"),
'Y': np.random.uniform(0.1, 1, [30, 1, 4, 1]).astype("float64"), 'Y': np.random.uniform(0.1, 1, [30, 1, 4, 1]).astype("float64"),
...@@ -219,6 +231,7 @@ class TestElementwiseDivOp_commonuse_2(ElementwiseDivOp): ...@@ -219,6 +231,7 @@ class TestElementwiseDivOp_commonuse_2(ElementwiseDivOp):
class TestElementwiseDivOp_xsize_lessthan_ysize(ElementwiseDivOp): class TestElementwiseDivOp_xsize_lessthan_ysize(ElementwiseDivOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_div" self.op_type = "elementwise_div"
self.python_api = paddle.divide
self.inputs = { self.inputs = {
'X': np.random.uniform(0.1, 1, [10, 12]).astype("float64"), 'X': np.random.uniform(0.1, 1, [10, 12]).astype("float64"),
'Y': np.random.uniform(0.1, 1, [2, 3, 10, 12]).astype("float64"), 'Y': np.random.uniform(0.1, 1, [2, 3, 10, 12]).astype("float64"),
...@@ -232,6 +245,7 @@ class TestElementwiseDivOp_xsize_lessthan_ysize(ElementwiseDivOp): ...@@ -232,6 +245,7 @@ class TestElementwiseDivOp_xsize_lessthan_ysize(ElementwiseDivOp):
class TestElementwiseDivOp_INT(OpTest): class TestElementwiseDivOp_INT(OpTest):
def setUp(self): def setUp(self):
self.op_type = "elementwise_div" self.op_type = "elementwise_div"
self.python_api = paddle.divide
self.dtype = np.int32 self.dtype = np.int32
self.init_dtype() self.init_dtype()
self.inputs = { self.inputs = {
...@@ -304,6 +318,7 @@ class TestDivideOp(unittest.TestCase): ...@@ -304,6 +318,7 @@ class TestDivideOp(unittest.TestCase):
class TestComplexElementwiseDivOp(OpTest): class TestComplexElementwiseDivOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "elementwise_div" self.op_type = "elementwise_div"
self.python_api = paddle.divide
self.init_base_dtype() self.init_base_dtype()
self.init_input_output() self.init_input_output()
self.init_grad_input_output() self.init_grad_input_output()
...@@ -334,7 +349,7 @@ class TestComplexElementwiseDivOp(OpTest): ...@@ -334,7 +349,7 @@ class TestComplexElementwiseDivOp(OpTest):
self.grad_y = -self.grad_out * np.conj(self.x / self.y / self.y) self.grad_y = -self.grad_out * np.conj(self.x / self.y / self.y)
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_eager=False)
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad( self.check_grad(
......
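`self.python_api = paddle.divide` points these op tests at the public eager API. A minimal reference check against NumPy, mirroring the test inputs (illustrative only):

```python
import numpy as np
import paddle

x_np = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
y_np = np.random.uniform(0.1, 1, [13, 17]).astype("float64")

out = paddle.divide(paddle.to_tensor(x_np), paddle.to_tensor(y_np))
assert np.allclose(out.numpy(), np.divide(x_np, y_np))
```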
@@ -24,6 +24,7 @@ import paddle.fluid as fluid
 class TestExpandAsOpRank1(OpTest):
     def setUp(self):
         self.op_type = "expand_as_v2"
+        self.python_api = paddle.expand_as
         x = np.random.rand(100).astype("float64")
         target_tensor = np.random.rand(2, 100).astype("float64")
         self.inputs = {'X': x}
......
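Likewise for `expand_as_v2`, a short illustrative check of the broadcast that `paddle.expand_as` performs:

```python
import numpy as np
import paddle

x = paddle.to_tensor(np.random.rand(100).astype("float64"))
target = paddle.to_tensor(np.random.rand(2, 100).astype("float64"))

out = paddle.expand_as(x, target)  # broadcast x up to target's shape
assert out.shape == [2, 100]
assert np.allclose(out.numpy()[0], out.numpy()[1])  # both rows are copies of x
```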
@@ -21,6 +21,7 @@ import paddle.fluid as fluid
 import paddle.fluid.core as core
 from paddle.fluid import Program, program_guard
 from op_test import OpTest
+from paddle.fluid.framework import _test_eager_guard


 class TestHistogramOpAPI(unittest.TestCase):
@@ -57,6 +58,15 @@ class TestHistogramOpAPI(unittest.TestCase):
                 (actual.numpy() == expected).all(),
                 msg='histogram output is wrong, out =' + str(actual.numpy()))

+            with _test_eager_guard():
+                inputs_np = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int64)
+                inputs = paddle.to_tensor(inputs_np)
+                actual = paddle.histogram(inputs, bins=5, min=1, max=5)
+                self.assertTrue(
+                    (actual.numpy() == expected).all(),
+                    msg='histogram output is wrong, out =' +
+                    str(actual.numpy()))
+

 class TestHistogramOpError(unittest.TestCase):
     """Test histogram op error."""
@@ -118,6 +128,7 @@ class TestHistogramOp(OpTest):
         self.op_type = "histogram"
         self.init_test_case()
         np_input = np.random.uniform(low=0.0, high=20.0, size=self.in_shape)
+        self.python_api = paddle.histogram
         self.inputs = {"X": np_input}
         self.init_attrs()
         Out, _ = np.histogram(
@@ -134,7 +145,7 @@ class TestHistogramOp(OpTest):
         self.attrs = {"bins": self.bins, "min": self.min, "max": self.max}

     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)


 if __name__ == "__main__":
......
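The eager branch added above can be reproduced standalone; a hedged sketch comparing `paddle.histogram` with the NumPy reference the test relies on:

```python
import numpy as np
import paddle

inputs_np = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int64)
expected, _ = np.histogram(inputs_np, bins=5, range=(1, 5))

actual = paddle.histogram(paddle.to_tensor(inputs_np), bins=5, min=1, max=5)
assert (actual.numpy() == expected).all()  # both give [0, 3, 0, 2, 1]
```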
@@ -40,10 +40,10 @@ class TestIndexSampleOp(OpTest):
         self.outputs = {'Out': out}

     def test_check_output(self):
-        self.check_output(check_eager=False)
+        self.check_output(check_eager=True)

     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', check_eager=False)
+        self.check_grad(['X'], 'Out', check_eager=True)

     def config(self):
         """
......
@@ -16,6 +16,7 @@ import paddle
 import paddle.fluid as fluid
 import unittest
 import numpy as np
+from paddle.fluid.framework import _test_eager_guard


 def run_static(x_np, dtype, op_str, use_gpu=False):
@@ -46,6 +47,18 @@ def run_dygraph(x_np, op_str, use_gpu=True):
         return dygraph_result


+def run_eager(x_np, op_str, use_gpu=True):
+    with paddle.fluid.dygraph.guard():
+        with _test_eager_guard():
+            place = paddle.CPUPlace()
+            if use_gpu and fluid.core.is_compiled_with_cuda():
+                place = paddle.CUDAPlace(0)
+
+            x = paddle.to_tensor(x_np)
+            dygraph_result = getattr(paddle.tensor, op_str)(x)
+            return dygraph_result
+
+
 def np_data_generator(low, high, np_shape, type, sv_list, op_str, *args,
                       **kwargs):
     x_np = np.random.uniform(low, high, np_shape).astype(getattr(np, type))
@@ -107,8 +120,10 @@ def test(test_case, op_str, use_gpu=False):
         x_np, result_np = np_data_generator(**meta_data)
         static_result = run_static(x_np, meta_data['type'], op_str, use_gpu)
         dygraph_result = run_dygraph(x_np, op_str, use_gpu)
+        eager_result = run_eager(x_np, op_str, use_gpu)
         test_case.assertTrue((static_result == result_np).all())
         test_case.assertTrue((dygraph_result.numpy() == result_np).all())
+        test_case.assertTrue((eager_result.numpy() == result_np).all())


 class TestCPUNormal(unittest.TestCase):
@@ -158,4 +173,5 @@ class TestError(unittest.TestCase):

 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
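The `run_eager` helper follows a pattern used throughout these tests: run the op under the eager guard and compare with NumPy. A self-contained sketch with `paddle.isnan` standing in for the op under test:

```python
import numpy as np
import paddle
from paddle.fluid.framework import _test_eager_guard

x_np = np.array([1.0, float("nan"), float("inf"), 2.0], dtype="float32")
expected = np.isnan(x_np)

with paddle.fluid.dygraph.guard():
    with _test_eager_guard():
        eager_result = paddle.isnan(paddle.to_tensor(x_np))

assert (eager_result.numpy() == expected).all()
```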
@@ -27,6 +27,7 @@ np.random.seed(0)
 class TestLerp(OpTest):
     def setUp(self):
         self.op_type = "lerp"
+        self.python_api = paddle.lerp
         self.init_dtype()
         self.init_shape()
         x = np.arange(1., 101.).astype(self.dtype).reshape(self.shape)
@@ -42,10 +43,10 @@ class TestLerp(OpTest):
         self.shape = [100]

     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)

     def test_check_grad(self):
-        self.check_grad(['X', 'Y'], 'Out')
+        self.check_grad(['X', 'Y'], 'Out', check_eager=True)


 class TestLerpWithDim2(TestLerp):
......
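`paddle.lerp` interpolates as `x + weight * (y - x)`; a tiny worked example with illustrative values (not taken from the test):

```python
import numpy as np
import paddle

x = paddle.to_tensor(np.array([1., 2., 3., 4.]))
y = paddle.to_tensor(np.full(4, 10.))

out = paddle.lerp(x, y, 0.5)  # x + 0.5 * (y - x)
print(out.numpy())            # [5.5, 6.0, 6.5, 7.0]
```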
@@ -20,6 +20,7 @@ import numpy as np
 import paddle
 import paddle.fluid as fluid
 from paddle.static import Program, program_guard
+from paddle.fluid.framework import _test_eager_guard

 SUPPORTED_DTYPES = [
     bool, np.int8, np.int16, np.int32, np.int64, np.float32, np.float64
@@ -144,6 +145,22 @@ def run_dygraph(x_np, y_np, op_str, use_gpu=False, binary_op=True):
         return dygraph_result


+def run_eager(x_np, y_np, op_str, use_gpu=False, binary_op=True):
+    place = paddle.CPUPlace()
+    if use_gpu and fluid.core.is_compiled_with_cuda():
+        place = paddle.CUDAPlace(0)
+    paddle.disable_static(place)
+    with _test_eager_guard():
+        op = getattr(paddle, op_str)
+        x = paddle.to_tensor(x_np, dtype=x_np.dtype)
+        if not binary_op:
+            dygraph_result = op(x)
+        else:
+            y = paddle.to_tensor(y_np, dtype=y_np.dtype)
+            dygraph_result = op(x, y)
+        return dygraph_result
+
+
 def np_data_generator(np_shape, dtype, *args, **kwargs):
     if dtype == bool:
         return np.random.choice(a=[True, False], size=np_shape).astype(bool)
@@ -174,6 +191,7 @@ def test(unit_test, use_gpu=False, test_error=False):
                 continue
             static_result = run_static(**meta_data)
             dygraph_result = run_dygraph(**meta_data)
+            eager_result = run_eager(**meta_data)
             if meta_data['binary_op']:
                 np_result = np_op(meta_data['x_np'], meta_data['y_np'])
             else:
@@ -181,6 +199,7 @@ def test(unit_test, use_gpu=False, test_error=False):
             unit_test.assertTrue((static_result == np_result).all())
             unit_test.assertTrue((dygraph_result.numpy() == np_result).all(
             ))
+            unit_test.assertTrue((eager_result.numpy() == np_result).all())


 def test_type_error(unit_test, use_gpu, type_str_map):
@@ -259,4 +278,5 @@ class TestCUDA(unittest.TestCase):

 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
@@ -33,6 +33,7 @@ class TestMaskedSelectOp(OpTest):
     def setUp(self):
         self.init()
         self.op_type = "masked_select"
+        self.python_api = paddle.masked_select
         x = np.random.random(self.shape).astype("float64")
         mask = np.array(np.random.randint(2, size=self.shape, dtype=bool))
         out = np_masked_select(x, mask)
@@ -40,10 +41,10 @@ class TestMaskedSelectOp(OpTest):
         self.outputs = {'Y': out}

     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)

     def test_check_grad(self):
-        self.check_grad(['X'], 'Y')
+        self.check_grad(['X'], 'Y', check_eager=True)

     def init(self):
         self.shape = (50, 3)
@@ -121,4 +122,5 @@ class TestMaskedSelectError(unittest.TestCase):

 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
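`paddle.masked_select` returns a 1-D tensor of the values where `mask` is True, in row-major order, which matches NumPy boolean indexing; a short illustrative check:

```python
import numpy as np
import paddle

x_np = np.random.random((5, 3)).astype("float64")
mask_np = np.random.randint(2, size=(5, 3)).astype(bool)

out = paddle.masked_select(paddle.to_tensor(x_np), paddle.to_tensor(mask_np))
assert (out.numpy() == x_np[mask_np]).all()
```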
...@@ -17,6 +17,7 @@ import paddle.fluid as fluid ...@@ -17,6 +17,7 @@ import paddle.fluid as fluid
import numpy as np import numpy as np
import unittest import unittest
from op_test import OpTest from op_test import OpTest
from paddle.fluid.framework import _test_eager_guard
def nll_loss_1d(logs, targets, weight=None, reduction='mean', def nll_loss_1d(logs, targets, weight=None, reduction='mean',
...@@ -97,14 +98,21 @@ class TestNLLLoss(unittest.TestCase): ...@@ -97,14 +98,21 @@ class TestNLLLoss(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss() nll_loss = paddle.nn.loss.NLLLoss()
dy_res = nll_loss( dy_res = nll_loss(
fluid.dygraph.to_variable(input_np), paddle.to_tensor(input_np), paddle.to_tensor(label_np))
fluid.dygraph.to_variable(label_np))
dy_result = dy_res.numpy() dy_result = dy_res.numpy()
with fluid.dygraph.guard():
with _test_eager_guard():
nll_loss = paddle.nn.loss.NLLLoss()
eager_res = nll_loss(
paddle.to_tensor(input_np), paddle.to_tensor(label_np))
eager_result = eager_res.numpy()
expected = nll_loss_1d(input_np, label_np)[0] expected = nll_loss_1d(input_np, label_np)[0]
self.assertTrue(np.allclose(static_result, expected)) self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result)) self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected)) self.assertTrue(np.allclose(dy_result, expected))
self.assertTrue(np.allclose(eager_result, expected))
def test_NLLLoss_1D_sum(self): def test_NLLLoss_1D_sum(self):
np.random.seed(200) np.random.seed(200)
...@@ -132,14 +140,24 @@ class TestNLLLoss(unittest.TestCase): ...@@ -132,14 +140,24 @@ class TestNLLLoss(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss(reduction='sum') nll_loss = paddle.nn.loss.NLLLoss(reduction='sum')
dy_res = nll_loss( dy_res = nll_loss(
fluid.dygraph.to_variable(input_np), paddle.to_tensor(input_np), paddle.to_tensor(label_np))
fluid.dygraph.to_variable(label_np))
dy_result = dy_res.numpy() dy_result = dy_res.numpy()
with _test_eager_guard():
nll_loss = paddle.nn.loss.NLLLoss(reduction='sum')
in_t = paddle.to_tensor(input_np)
label = paddle.to_tensor(label_np)
in_t.stop_gradient = False
eager_res = nll_loss(in_t, label)
eager_result = eager_res.numpy()
loss = eager_res.sum()
loss.backward()
expected = nll_loss_1d(input_np, label_np, reduction='sum')[0] expected = nll_loss_1d(input_np, label_np, reduction='sum')[0]
self.assertTrue(np.allclose(static_result, expected)) self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result)) self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected)) self.assertTrue(np.allclose(dy_result, expected))
self.assertTrue(np.allclose(eager_result, expected))
def test_NLLLoss_1D_with_weight_mean(self): def test_NLLLoss_1D_with_weight_mean(self):
np.random.seed(200) np.random.seed(200)
...@@ -170,16 +188,26 @@ class TestNLLLoss(unittest.TestCase): ...@@ -170,16 +188,26 @@ class TestNLLLoss(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss( nll_loss = paddle.nn.loss.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np)) weight=paddle.to_tensor(weight_np))
dy_res = nll_loss( dy_res = nll_loss(
fluid.dygraph.to_variable(input_np), paddle.to_tensor(input_np), paddle.to_tensor(label_np))
fluid.dygraph.to_variable(label_np))
dy_result = dy_res.numpy() dy_result = dy_res.numpy()
with _test_eager_guard():
nll_loss = paddle.nn.loss.NLLLoss(
weight=paddle.to_tensor(weight_np))
eager_res = nll_loss(
paddle.to_tensor(input_np), paddle.to_tensor(label_np))
loss = eager_res.sum()
loss.backward()
eager_result = eager_res.numpy()
expected = nll_loss_1d(input_np, label_np, weight=weight_np)[0] expected = nll_loss_1d(input_np, label_np, weight=weight_np)[0]
self.assertTrue(np.allclose(static_result, expected)) self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result)) self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected)) self.assertTrue(np.allclose(dy_result, expected))
self.assertTrue(np.allclose(eager_result, expected))
def test_NLLLoss_1D_with_weight_sum(self): def test_NLLLoss_1D_with_weight_sum(self):
np.random.seed(200) np.random.seed(200)
...@@ -210,10 +238,9 @@ class TestNLLLoss(unittest.TestCase): ...@@ -210,10 +238,9 @@ class TestNLLLoss(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss( nll_loss = paddle.nn.loss.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np), reduction='sum') weight=paddle.to_tensor(weight_np), reduction='sum')
dy_res = nll_loss( dy_res = nll_loss(
fluid.dygraph.to_variable(input_np), paddle.to_tensor(input_np), paddle.to_tensor(label_np))
fluid.dygraph.to_variable(label_np))
dy_result = dy_res.numpy() dy_result = dy_res.numpy()
expected = nll_loss_1d( expected = nll_loss_1d(
input_np, label_np, weight=weight_np, reduction='sum')[0] input_np, label_np, weight=weight_np, reduction='sum')[0]
...@@ -249,10 +276,9 @@ class TestNLLLoss(unittest.TestCase): ...@@ -249,10 +276,9 @@ class TestNLLLoss(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss( nll_loss = paddle.nn.loss.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np)) weight=paddle.to_tensor(weight_np))
dy_res = nll_loss( dy_res = nll_loss(
fluid.dygraph.to_variable(input_np), paddle.to_tensor(input_np), paddle.to_tensor(label_np))
fluid.dygraph.to_variable(label_np))
dy_result = dy_res.numpy() dy_result = dy_res.numpy()
expected = nll_loss_1d(input_np, label_np, weight=weight_np)[0] expected = nll_loss_1d(input_np, label_np, weight=weight_np)[0]
...@@ -287,10 +313,9 @@ class TestNLLLoss(unittest.TestCase): ...@@ -287,10 +313,9 @@ class TestNLLLoss(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss( nll_loss = paddle.nn.loss.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np), reduction='none') weight=paddle.to_tensor(weight_np), reduction='none')
dy_res = nll_loss( dy_res = nll_loss(
fluid.dygraph.to_variable(input_np), paddle.to_tensor(input_np), paddle.to_tensor(label_np))
fluid.dygraph.to_variable(label_np))
dy_result = dy_res.numpy() dy_result = dy_res.numpy()
expected = nll_loss_1d( expected = nll_loss_1d(
input_np, label_np, weight=weight_np, reduction='none') input_np, label_np, weight=weight_np, reduction='none')
...@@ -326,8 +351,7 @@ class TestNLLLoss(unittest.TestCase): ...@@ -326,8 +351,7 @@ class TestNLLLoss(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss() nll_loss = paddle.nn.loss.NLLLoss()
dy_res = nll_loss( dy_res = nll_loss(
fluid.dygraph.to_variable(input_np), paddle.to_tensor(input_np), paddle.to_tensor(label_np))
fluid.dygraph.to_variable(label_np))
dy_result = dy_res.numpy() dy_result = dy_res.numpy()
expected = nll_loss_2d(input_np, label_np)[0] expected = nll_loss_2d(input_np, label_np)[0]
...@@ -363,8 +387,7 @@ class TestNLLLoss(unittest.TestCase): ...@@ -363,8 +387,7 @@ class TestNLLLoss(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss(reduction='sum') nll_loss = paddle.nn.loss.NLLLoss(reduction='sum')
dy_res = nll_loss( dy_res = nll_loss(
fluid.dygraph.to_variable(input_np), paddle.to_tensor(input_np), paddle.to_tensor(label_np))
fluid.dygraph.to_variable(label_np))
dy_result = dy_res.numpy() dy_result = dy_res.numpy()
expected = nll_loss_2d(input_np, label_np, reduction='sum')[0] expected = nll_loss_2d(input_np, label_np, reduction='sum')[0]
...@@ -404,10 +427,9 @@ class TestNLLLoss(unittest.TestCase): ...@@ -404,10 +427,9 @@ class TestNLLLoss(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss( nll_loss = paddle.nn.loss.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np)) weight=paddle.to_tensor(weight_np))
dy_res = nll_loss( dy_res = nll_loss(
fluid.dygraph.to_variable(input_np), paddle.to_tensor(input_np), paddle.to_tensor(label_np))
fluid.dygraph.to_variable(label_np))
dy_result = dy_res.numpy() dy_result = dy_res.numpy()
expected = nll_loss_2d(input_np, label_np, weight=weight_np)[0] expected = nll_loss_2d(input_np, label_np, weight=weight_np)[0]
...@@ -445,10 +467,9 @@ class TestNLLLoss(unittest.TestCase): ...@@ -445,10 +467,9 @@ class TestNLLLoss(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss( nll_loss = paddle.nn.loss.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np)) weight=paddle.to_tensor(weight_np))
dy_res = nll_loss( dy_res = nll_loss(
fluid.dygraph.to_variable(input_np), paddle.to_tensor(input_np), paddle.to_tensor(label_np))
fluid.dygraph.to_variable(label_np))
dy_result = dy_res.numpy() dy_result = dy_res.numpy()
expected = nll_loss_2d(input_np, label_np, weight=weight_np)[0] expected = nll_loss_2d(input_np, label_np, weight=weight_np)[0]
...@@ -487,10 +508,9 @@ class TestNLLLoss(unittest.TestCase): ...@@ -487,10 +508,9 @@ class TestNLLLoss(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss( nll_loss = paddle.nn.loss.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np), reduction='sum') weight=paddle.to_tensor(weight_np), reduction='sum')
dy_res = nll_loss( dy_res = nll_loss(
fluid.dygraph.to_variable(input_np), paddle.to_tensor(input_np), paddle.to_tensor(label_np))
fluid.dygraph.to_variable(label_np))
dy_result = dy_res.numpy() dy_result = dy_res.numpy()
expected = nll_loss_2d( expected = nll_loss_2d(
...@@ -527,8 +547,7 @@ class TestNLLLoss(unittest.TestCase): ...@@ -527,8 +547,7 @@ class TestNLLLoss(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss() nll_loss = paddle.nn.loss.NLLLoss()
dy_res = nll_loss( dy_res = nll_loss(
fluid.dygraph.to_variable(input_np), paddle.to_tensor(input_np), paddle.to_tensor(label_np))
fluid.dygraph.to_variable(label_np))
dy_result = dy_res.numpy() dy_result = dy_res.numpy()
input_shape = input_np.shape input_shape = input_np.shape
...@@ -572,10 +591,9 @@ class TestNLLLoss(unittest.TestCase): ...@@ -572,10 +591,9 @@ class TestNLLLoss(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss( nll_loss = paddle.nn.loss.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np)) weight=paddle.to_tensor(weight_np))
dy_res = nll_loss( dy_res = nll_loss(
fluid.dygraph.to_variable(input_np), paddle.to_tensor(input_np), paddle.to_tensor(label_np))
fluid.dygraph.to_variable(label_np))
dy_result = dy_res.numpy() dy_result = dy_res.numpy()
input_shape = input_np.shape input_shape = input_np.shape
...@@ -620,10 +638,9 @@ class TestNLLLoss(unittest.TestCase): ...@@ -620,10 +638,9 @@ class TestNLLLoss(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss( nll_loss = paddle.nn.loss.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np), reduction='sum') weight=paddle.to_tensor(weight_np), reduction='sum')
dy_res = nll_loss( dy_res = nll_loss(
fluid.dygraph.to_variable(input_np), paddle.to_tensor(input_np), paddle.to_tensor(label_np))
fluid.dygraph.to_variable(label_np))
dy_result = dy_res.numpy() dy_result = dy_res.numpy()
input_shape = input_np.shape input_shape = input_np.shape
...@@ -671,10 +688,9 @@ class TestNLLLoss(unittest.TestCase): ...@@ -671,10 +688,9 @@ class TestNLLLoss(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss( nll_loss = paddle.nn.loss.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np), reduction='none') weight=paddle.to_tensor(weight_np), reduction='none')
dy_res = nll_loss( dy_res = nll_loss(
fluid.dygraph.to_variable(input_np), paddle.to_tensor(input_np), paddle.to_tensor(label_np))
fluid.dygraph.to_variable(label_np))
dy_result = dy_res.numpy() dy_result = dy_res.numpy()
input_shape = input_np.shape input_shape = input_np.shape
...@@ -721,10 +737,9 @@ class TestNLLLoss(unittest.TestCase): ...@@ -721,10 +737,9 @@ class TestNLLLoss(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss( nll_loss = paddle.nn.loss.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np), reduction='none') weight=paddle.to_tensor(weight_np), reduction='none')
dy_res = nll_loss( dy_res = nll_loss(
fluid.dygraph.to_variable(input_np), paddle.to_tensor(input_np), paddle.to_tensor(label_np))
fluid.dygraph.to_variable(label_np))
dy_result = dy_res.numpy() dy_result = dy_res.numpy()
input_shape = input_np.shape input_shape = input_np.shape
...@@ -749,6 +764,8 @@ class TestNLLLossOp1DWithReduce(OpTest): ...@@ -749,6 +764,8 @@ class TestNLLLossOp1DWithReduce(OpTest):
self.init_test_case() self.init_test_case()
self.op_type = "nll_loss" self.op_type = "nll_loss"
self.with_weight = False self.with_weight = False
self.python_api = paddle.nn.functional.nll_loss
self.python_out_sig = ["Out"]
np.random.seed(200) np.random.seed(200)
input_np = np.random.uniform(0.1, 0.8, input_np = np.random.uniform(0.1, 0.8,
self.input_shape).astype("float64") self.input_shape).astype("float64")
...@@ -769,7 +786,7 @@ class TestNLLLossOp1DWithReduce(OpTest): ...@@ -769,7 +786,7 @@ class TestNLLLossOp1DWithReduce(OpTest):
self.attrs = {'reduction': 'mean', 'ignore_index': -100} self.attrs = {'reduction': 'mean', 'ignore_index': -100}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_eager=False)
def test_check_output_with_weight(self): def test_check_output_with_weight(self):
self.with_weight = True self.with_weight = True
...@@ -778,7 +795,7 @@ class TestNLLLossOp1DWithReduce(OpTest): ...@@ -778,7 +795,7 @@ class TestNLLLossOp1DWithReduce(OpTest):
def test_check_grad(self): def test_check_grad(self):
self.with_weight = True self.with_weight = True
place = fluid.CPUPlace() place = fluid.CPUPlace()
self.check_grad_with_place(place, ['X'], 'Out') self.check_grad_with_place(place, ['X'], 'Out', check_eager=False)
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0) place = fluid.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out') self.check_grad_with_place(place, ['X'], 'Out')
...@@ -1014,4 +1031,5 @@ class TestNLLLossInvalidArgs(unittest.TestCase): ...@@ -1014,4 +1031,5 @@ class TestNLLLossInvalidArgs(unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static()
unittest.main() unittest.main()
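For orientation on the NLLLoss tests above, a compact sketch of the dygraph call path they exercise; the uniform inputs stand in for log-probabilities purely for shape and flow illustration, as in the tests themselves:

```python
import numpy as np
import paddle

np.random.seed(200)
input_np = np.random.uniform(0.1, 0.8, (10, 10)).astype("float64")
label_np = np.random.randint(0, 10, size=(10, )).astype("int64")

x = paddle.to_tensor(input_np, stop_gradient=False)
label = paddle.to_tensor(label_np)

nll_loss = paddle.nn.loss.NLLLoss()  # mean reduction by default
loss = nll_loss(x, label)
loss.backward()  # drives nll_loss_grad with the reordered inputs

print(loss.numpy(), x.grad is not None)
```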
@@ -18,6 +18,7 @@ import unittest
 import numpy as np
 from op_test import OpTest
 import paddle.fluid.core as core
+import paddle


 class TestTopkOp(OpTest):
@@ -61,4 +62,5 @@ class TestTopkOp(OpTest):

 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
@@ -19,6 +19,7 @@ import numpy as np
 from op_test import OpTest
 import paddle
 import paddle.fluid.core as core
+from paddle.fluid.framework import _test_eager_guard
 def numpy_topk(x, k=1, axis=-1, largest=True):
@@ -45,6 +46,7 @@ class TestTopkOp(OpTest):
     def setUp(self):
         self.op_type = "top_k_v2"
+        self.python_api = paddle.topk
         self.dtype = np.float64
         self.input_data = np.random.rand(10, 20)
         self.init_args()
@@ -55,12 +57,10 @@ class TestTopkOp(OpTest):
         self.outputs = {'Out': output, 'Indices': indices}
     def test_check_output(self):
-        paddle.enable_static()
-        self.check_output()
+        self.check_output(check_eager=False)
     def test_check_grad(self):
-        paddle.enable_static()
-        self.check_grad(set(['X']), 'Out')
+        self.check_grad(set(['X']), 'Out', check_eager=False)
 class TestTopkOp1(TestTopkOp):
@@ -85,6 +85,7 @@ class TestTopkOp3(OpTest):
     def setUp(self):
         self.op_type = "top_k_v2"
+        self.python_api = paddle.topk
         self.dtype = np.float64
         self.input_data = np.random.rand(16, 100)
         self.init_args()
@@ -103,6 +104,7 @@ class TestTopkOp4(TestTopkOp):
     def setUp(self):
         self.op_type = "top_k_v2"
+        self.python_api = paddle.topk
         self.dtype = np.float64
         self.input_data = np.random.rand(10, 10, 5)
         self.init_args()
@@ -121,6 +123,7 @@ class TestTopkOp5(TestTopkOp):
     def setUp(self):
         self.op_type = "top_k_v2"
+        self.python_api = paddle.topk
         self.dtype = np.float64
         self.input_data = np.random.rand(10, 10, 5)
         self.init_args()
@@ -139,6 +142,7 @@ class TestTopkOp6(OpTest):
     def setUp(self):
         self.op_type = "top_k_v2"
+        self.python_api = paddle.topk
         self.dtype = np.float64
         self.input_data = np.random.rand(80, 16384)
         self.init_args()
@@ -156,48 +160,64 @@ class TestTopKAPI(unittest.TestCase):
         self.large_input_data = np.random.rand(2, 1030)
     def run_dygraph(self, place):
-        paddle.disable_static(place)
+        with paddle.fluid.dygraph.guard(place):
             input_tensor = paddle.to_tensor(self.input_data)
             large_input_tensor = paddle.to_tensor(self.large_input_data)
             # test case for basic test case 1
             paddle_result = paddle.topk(input_tensor, k=2)
             numpy_result = numpy_topk(self.input_data, k=2)
             self.assertTrue(
                 np.allclose(paddle_result[0].numpy(), numpy_result[0]))
             self.assertTrue(
                 np.allclose(paddle_result[1].numpy(), numpy_result[1]))
             # test case for basic test case 2 with axis
             paddle_result = paddle.topk(input_tensor, k=2, axis=1)
             numpy_result = numpy_topk(self.input_data, k=2, axis=1)
             self.assertTrue(
                 np.allclose(paddle_result[0].numpy(), numpy_result[0]))
             self.assertTrue(
                 np.allclose(paddle_result[1].numpy(), numpy_result[1]))
             # test case for basic test case 3 with tensor K
             k_tensor = paddle.to_tensor(np.array([2]))
             paddle_result = paddle.topk(input_tensor, k=k_tensor, axis=1)
             numpy_result = numpy_topk(self.input_data, k=2, axis=1)
             self.assertTrue(
                 np.allclose(paddle_result[0].numpy(), numpy_result[0]))
             self.assertTrue(
                 np.allclose(paddle_result[1].numpy(), numpy_result[1]))
             # test case for basic test case 4 with tensor largest
             k_tensor = paddle.to_tensor(np.array([2]))
             paddle_result = paddle.topk(
                 input_tensor, k=2, axis=1, largest=False)
             numpy_result = numpy_topk(
                 self.input_data, k=2, axis=1, largest=False)
             self.assertTrue(
                 np.allclose(paddle_result[0].numpy(), numpy_result[0]))
             self.assertTrue(
                 np.allclose(paddle_result[1].numpy(), numpy_result[1]))
             # test case for basic test case 5 with axis -1
             k_tensor = paddle.to_tensor(np.array([2]))
             paddle_result = paddle.topk(
                 input_tensor, k=2, axis=-1, largest=False)
             numpy_result = numpy_topk(
                 self.input_data, k=2, axis=-1, largest=False)
             self.assertTrue(
                 np.allclose(paddle_result[0].numpy(), numpy_result[0]))
             self.assertTrue(
                 np.allclose(paddle_result[1].numpy(), numpy_result[1]))
             # test case for basic test case 6 for the partial sort
             paddle_result = paddle.topk(large_input_tensor, k=1, axis=-1)
             numpy_result = numpy_topk(self.large_input_data, k=1, axis=-1)
             self.assertTrue(
                 np.allclose(paddle_result[0].numpy(), numpy_result[0]))
             self.assertTrue(
                 np.allclose(paddle_result[1].numpy(), numpy_result[1]))
             # test case for basic test case 7 for the unsorted
             paddle_result = paddle.topk(input_tensor, k=2, axis=1, sorted=False)
             sort_paddle = numpy_topk(
                 np.array(paddle_result[0].numpy()), axis=1, k=2)
             numpy_result = numpy_topk(self.input_data, k=2, axis=1)
             self.assertTrue(np.allclose(sort_paddle[0], numpy_result[0]))
     def run_static(self, place):
         paddle.enable_static()
@@ -264,14 +284,15 @@ class TestTopKAPI(unittest.TestCase):
         self.run_static(place)
     def test_errors(self):
-        paddle.disable_static()
+        with paddle.fluid.dygraph.guard():
             x = paddle.to_tensor([1, 2, 3])
             with self.assertRaises(BaseException):
                 paddle.topk(x, k=-1)
             with self.assertRaises(BaseException):
                 paddle.topk(x, k=0)
 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
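The assertions above all compare `paddle.topk` against the `numpy_topk` reference helper defined near the top of the file. For readers of this diff, a minimal sketch of what such a NumPy reference looks like (the function name and tie-breaking details here are illustrative, not the test file's verbatim code):

```python
import numpy as np

def numpy_topk_ref(x, k=1, axis=-1, largest=True):
    # sort ascending, flip for largest-first, then slice the first k indices
    order = np.argsort(x, axis=axis)
    if largest:
        order = np.flip(order, axis=axis)
    indices = np.take(order, np.arange(k), axis=axis)
    values = np.take_along_axis(x, indices, axis=axis)
    return values, indices

x = np.random.rand(10, 20)
values, indices = numpy_topk_ref(x, k=2, axis=1)
print(values.shape, indices.shape)  # (10, 2) (10, 2)
```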
@@ -74,6 +74,7 @@ class TestViterbiOp(OpTest):
     def setUp(self):
         self.op_type = "viterbi_decode"
+        self.python_api = paddle.text.viterbi_decode
         self.set_attr()
         bz, length, ntags = self.bz, self.len, self.ntags
         self.input = np.random.randn(bz, length, ntags).astype(self.dtype)
@@ -90,7 +91,7 @@ class TestViterbiOp(OpTest):
         self.outputs = {'Scores': scores, 'Path': path}
     def test_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 class TestViterbiAPI(unittest.TestCase):
@@ -132,3 +133,8 @@ class TestViterbiAPI(unittest.TestCase):
     def test_static_net(self):
         for place in self.places:
             self.check_static_result(place)
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
@@ -31,7 +31,7 @@ def YoloBox(x, img_size, attrs):
     an_num = int((len(anchors) // 2))
     class_num = attrs['class_num']
     conf_thresh = attrs['conf_thresh']
-    downsample = attrs['downsample']
+    downsample = attrs['downsample_ratio']
     clip_bbox = attrs['clip_bbox']
     scale_x_y = attrs['scale_x_y']
     iou_aware = attrs['iou_aware']
@@ -92,13 +92,14 @@ class TestYoloBoxOp(OpTest):
     def setUp(self):
         self.initTestCase()
         self.op_type = 'yolo_box'
+        self.python_api = paddle.vision.ops.yolo_box
         x = np.random.random(self.x_shape).astype('float32')
         img_size = np.random.randint(10, 20, self.imgsize_shape).astype('int32')
         self.attrs = {
             'anchors': self.anchors,
             'class_num': self.class_num,
             'conf_thresh': self.conf_thresh,
-            'downsample': self.downsample,
+            'downsample_ratio': self.downsample,
             'clip_bbox': self.clip_bbox,
             'scale_x_y': self.scale_x_y,
             'iou_aware': self.iou_aware,
......
@@ -28,7 +28,7 @@ from ...tensor import clip
 from ...tensor import sum
 from ...tensor import sqrt
 from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
-from ...fluid.framework import _varbase_creator
+from ...fluid.framework import _varbase_creator, _in_legacy_dygraph, in_dygraph_mode
 from ...fluid import dygraph_utils
 from ...fluid import layers
@@ -1616,7 +1616,7 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
     if epsilon > 1. or epsilon < 0.:
         raise ValueError("The value of epsilon must be between 0 and 1.")
-    if in_dynamic_mode():
+    if paddle.in_dynamic_mode():
         return _C_ops.label_smooth(label, prior_dist, 'epsilon', float(epsilon))
     check_variable_and_dtype(label, 'label', ['float32', 'float64'],
......
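The dygraph branch above only changes how `label_smooth` dispatches; the computation is still `(1 - epsilon) * label + epsilon * prior`, with a uniform `1 / num_classes` prior when `prior_dist` is None. A small sanity check (input values are illustrative):

```python
import paddle
import paddle.nn.functional as F

label = paddle.to_tensor([[0., 1., 0.], [1., 0., 0.]])
out = F.label_smooth(label, epsilon=0.1)
# each entry becomes (1 - 0.1) * label + 0.1 / 3
print(out)  # roughly [[0.0333, 0.9333, 0.0333], [0.9333, 0.0333, 0.0333]]
```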
@@ -37,7 +37,7 @@ from paddle.utils import deprecated
 from paddle import _C_ops
 from paddle import in_dynamic_mode
 from paddle.framework import core
-from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode
+from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_mode
 __all__ = []
@@ -784,11 +784,12 @@ def nll_loss(input,
                 input_dims))
     n = input_shape[0]
     c = input_shape[1]
-    if in_dynamic_mode():
+    if _non_static_mode():
         if input_dims != 2 and input_dims != 4:
             input, _ = _C_ops.reshape2(input, None, 'shape', [n, c, 1, -1])
             label, _ = _C_ops.reshape2(label, None, 'shape', [n, 1, -1])
             out_shape = [n] + input_shape[2:]
         out, total_weight = _C_ops.nll_loss(input, label, weight,
                                             'ignore_index', ignore_index,
                                             'reduction', reduction)
......
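The `_non_static_mode()` branch above reshapes an N-d input to the `[n, c, 1, -1]` layout the kernel expects before calling `_C_ops.nll_loss`. The public API is unchanged and still takes log-probabilities. A short usage sketch (values are illustrative):

```python
import paddle
import paddle.nn.functional as F

log_prob = F.log_softmax(paddle.randn([4, 3]), axis=1)   # [N, C] log-probabilities
label = paddle.to_tensor([0, 2, 1, 2], dtype='int64')    # [N]
loss = F.nll_loss(log_prob, label, reduction='mean')
print(loss)  # a single scalar loss value
```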
@@ -181,7 +181,7 @@ def batch_norm(x,
     trainable_statistics = not use_global_stats
     if in_dynamic_mode():
+        # for dygraph need tuple
         attrs = ("momentum", momentum, "epsilon", epsilon, "is_test",
                  not training, "data_layout", data_format, "use_mkldnn", False,
                  "fuse_with_relu", False, "use_global_stats", use_global_stats,
......
@@ -1397,7 +1397,10 @@ def histogram(input, bins=100, min=0, max=0, name=None):
            result = paddle.histogram(inputs, bins=4, min=0, max=3)
            print(result) # [0, 2, 1, 0]
    """
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_histogram(input, bins, min, max)
+    if _in_legacy_dygraph():
         return _C_ops.histogram(input, "bins", bins, "min", min, "max", max)
     helper = LayerHelper('histogram', **locals())
......
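The new eager branch passes `bins`, `min` and `max` positionally to the generated `final_state_histogram` kernel; observable behaviour through `paddle.histogram` stays the same. A quick check (input values here are illustrative, not the docstring's):

```python
import paddle

inputs = paddle.to_tensor([1, 2, 1, 0])
result = paddle.histogram(inputs, bins=4, min=0, max=3)
print(result)  # [1, 2, 1, 0]
```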
@@ -536,6 +536,8 @@ def bitwise_and(x, y, out=None, name=None):
            res = paddle.bitwise_and(x, y)
            print(res) # [0, 2, 1]
    """
+    if in_dygraph_mode() and out == None:
+        return _C_ops.final_state_bitwise_and(x, y)
     return _bitwise_op(
         op_name="bitwise_and", x=x, y=y, name=name, out=out, binary_op=True)
@@ -562,6 +564,9 @@ def bitwise_or(x, y, out=None, name=None):
            res = paddle.bitwise_or(x, y)
            print(res) # [-1, -1, -3]
    """
+    if in_dygraph_mode() and out == None:
+        return _C_ops.final_state_bitwise_or(x, y)
     return _bitwise_op(
         op_name="bitwise_or", x=x, y=y, name=name, out=out, binary_op=True)
@@ -588,6 +593,8 @@ def bitwise_xor(x, y, out=None, name=None):
            res = paddle.bitwise_xor(x, y)
            print(res) # [-1, -3, -4]
    """
+    if in_dygraph_mode() and out == None:
+        return _C_ops.final_state_bitwise_xor(x, y)
     return _bitwise_op(
         op_name="bitwise_xor", x=x, y=y, name=name, out=out, binary_op=True)
@@ -612,6 +619,8 @@ def bitwise_not(x, out=None, name=None):
            res = paddle.bitwise_not(x)
            print(res) # [4, 0, -2]
    """
+    if in_dygraph_mode() and out == None:
+        return _C_ops.final_state_bitwise_not(x)
     return _bitwise_op(
         op_name="bitwise_not", x=x, y=None, name=name, out=out, binary_op=False)
......
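All four bitwise wrappers gain the same guard: in eager dygraph mode with no `out` tensor supplied, dispatch straight to the generated `final_state_*` op, otherwise fall back to `_bitwise_op`. The Python-level results are the same either way; for inputs consistent with the docstring outputs shown above:

```python
import paddle

x = paddle.to_tensor([-5, -1, 1])
y = paddle.to_tensor([4, 2, -3])
print(paddle.bitwise_and(x, y))  # [0, 2, 1]
print(paddle.bitwise_or(x, y))   # [-1, -1, -3]
print(paddle.bitwise_xor(x, y))  # [-1, -3, -4]
print(paddle.bitwise_not(x))     # [4, 0, -2]
```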
@@ -17,7 +17,7 @@ from collections import Counter
 from ..static import Variable, device_guard
 from ..framework import core
-from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _in_eager_without_dygraph_check
+from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _in_eager_without_dygraph_check, _non_static_mode
 from ..fluid.layer_helper import LayerHelper
 from ..framework import OpProtoHolder, convert_np_dtype_to_dtype_, dygraph_only
 from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
@@ -1845,7 +1845,7 @@ def expand_as(x, y, name=None):
            np_out = out.numpy()
            # [[1, 2, 3], [1, 2, 3]]
    """
-    if paddle.in_dynamic_mode():
+    if _non_static_mode():
         return _C_ops.expand_as_v2(x, 'target_shape', y.shape)
     check_variable_and_dtype(
......
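`_non_static_mode()` covers both the legacy and the eager dygraph paths here, since `expand_as_v2` still takes the target shape as a plain attribute. Usage is unchanged; the docstring example above still holds:

```python
import paddle

x = paddle.to_tensor([1, 2, 3], dtype='int32')
y = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype='int32')
out = paddle.expand_as(x, y)
print(out.numpy())  # [[1 2 3] [1 2 3]]
```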
@@ -2681,7 +2681,9 @@ def isfinite(x, name=None):
            out = paddle.tensor.isfinite(x)
            print(out) # [False True True False True False False]
    """
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_isfinite( x )
+    if _in_legacy_dygraph():
         return _C_ops.isfinite_v2(x)
     helper = LayerHelper("isfinite_v2", **locals())
     check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'isfinite')
@@ -2709,7 +2711,9 @@ def isinf(x, name=None):
            out = paddle.tensor.isinf(x)
            print(out) # [ True False False True False False False]
    """
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_isinf( x )
+    if _in_legacy_dygraph():
         return _C_ops.isinf_v2(x)
     helper = LayerHelper("isinf_v2", **locals())
     check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'isinf')
@@ -2737,7 +2741,10 @@ def isnan(x, name=None):
            out = paddle.tensor.isnan(x)
            print(out) # [False False False False False True True]
    """
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_isnan( x )
+    if _in_legacy_dygraph():
         return _C_ops.isnan_v2(x)
     helper = LayerHelper("isnan_v2", **locals())
     check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'isnan')
@@ -3387,8 +3394,13 @@ def lerp(x, y, weight, name=None):
            # out: [5.5, 6., 6.5, 7.]
    """
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
         check_type(weight, 'weight', (float, paddle.Tensor, Variable), 'lerp')
+        if isinstance(weight, float):
+            weight = paddle.to_tensor(weight, dtype=x.dtype)
+        return _C_ops.final_state_lerp( x, y, weight)
+    if _in_legacy_dygraph():
         if isinstance(weight, float):
             weight = paddle.to_tensor(weight, dtype=x.dtype)
         return _C_ops.lerp(x, y, weight)
......
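In the new eager branch a float `weight` is promoted to a tensor first, so `final_state_lerp` only has to handle tensor weights; numerically the result is still `x + weight * (y - x)`, as in the docstring above:

```python
import paddle

x = paddle.arange(1., 5., dtype='float32')   # [1., 2., 3., 4.]
y = paddle.full([4], 10., dtype='float32')
print(paddle.lerp(x, y, 0.5))                # [5.5, 6.0, 6.5, 7.0]
```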
@@ -18,7 +18,7 @@ from ..fluid.layer_helper import LayerHelper
 from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype
 from ..fluid import layers
 from ..framework import core
-from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode
+from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_mode
 from paddle.common_ops_import import convert_np_dtype_to_dtype_
 from paddle.common_ops_import import Variable
 from paddle.common_ops_import import VarDesc
@@ -774,7 +774,10 @@ def masked_select(x, mask, name=None):
            #[1.0 5.0 6.0 9.0]
    """
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_masked_select(x, mask)
+    if _in_legacy_dygraph():
         return _C_ops.masked_select(x, mask)
     helper = LayerHelper("masked_select", **locals())
@@ -844,8 +847,8 @@ def topk(x, k, axis=None, largest=True, sorted=True, name=None):
        # [[1 1 0 0]]
    """
-    if paddle.in_dynamic_mode():
-        k = k.numpy().item(0) if isinstance(k, Variable) else k
+    if _non_static_mode():
         if axis is None:
             out, indices = _C_ops.top_k_v2(x, 'k',
                                            int(k), 'largest', largest, 'sorted',
......
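`masked_select` follows the same three-way dispatch as the other ops in this PR (final_state in eager mode, legacy `_C_ops` otherwise, `LayerHelper` in static graph), and `topk` now routes every dygraph call through `_non_static_mode()`. Behaviour at the Python level is unchanged, e.g.:

```python
import paddle

x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = x > 3.5
print(paddle.masked_select(x, mask))   # [4., 5., 6.]

values, indices = paddle.topk(x, k=2, axis=1)
print(values.numpy())                  # [[3. 2.] [6. 5.]]
```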
@@ -13,7 +13,7 @@
 # limitations under the License.
 from ..nn import Layer
-from ..fluid.framework import core, _non_static_mode
+from ..fluid.framework import core, _non_static_mode, in_dygraph_mode
 from ..fluid.layer_helper import LayerHelper
 from ..fluid.data_feeder import check_variable_and_dtype, check_type
 from paddle import _C_ops
@@ -58,6 +58,10 @@ def viterbi_decode(potentials,
            transition = paddle.rand((num_tags, num_tags), dtype='float32')
            scores, path = paddle.text.viterbi_decode(emission, transition, length, False) # scores: [3.37089300, 1.56825531], path: [[1, 0, 0], [1, 1, 0]]
    """
+    if in_dygraph_mode():
+        return _C_ops.final_state_viterbi_decode(potentials, transition_params,
+                                                 lengths, include_bos_eos_tag)
     if _non_static_mode():
         return _C_ops.viterbi_decode(potentials, transition_params, lengths,
                                      'include_bos_eos_tag', include_bos_eos_tag)
......
@@ -547,6 +547,15 @@
   func : hard_sigmoid
   backward : hard_sigmoid_grad
+# histogram
+- api : histogram
+  args : (Tensor x, int64_t bins, int min, int max)
+  output : Tensor
+  infer_meta :
+    func : HistogramInferMeta
+  kernel :
+    func : histogram
 - api : huber_loss
   args : (Tensor input, Tensor label, float delta)
   output : Tensor(out), Tensor(residual)
......
@@ -19,7 +19,7 @@ from ..fluid import core, layers
 from ..fluid.layers import nn, utils
 from ..nn import Layer, Conv2D, Sequential, ReLU, BatchNorm2D
 from ..fluid.initializer import Normal
-from ..fluid.framework import _non_static_mode
+from ..fluid.framework import _non_static_mode, in_dygraph_mode
 from paddle.common_ops_import import *
 from paddle import _C_ops
@@ -377,6 +377,12 @@ def yolo_box(x,
                               clip_bbox=True,
                               scale_x_y=1.)
    """
+    if in_dygraph_mode():
+        boxes, scores = _C_ops.final_state_yolo_box(
+            x, img_size, anchors, class_num, conf_thresh, downsample_ratio,
+            clip_bbox, scale_x_y, iou_aware, iou_aware_factor)
+        return boxes, scores
     if _non_static_mode():
         boxes, scores = _C_ops.yolo_box(
             x, img_size, 'anchors', anchors, 'class_num', class_num,
......
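The eager branch passes every attribute positionally to the generated `final_state_yolo_box` op. A rough usage sketch of the public API, under the usual layout assumption that the channel dimension is `anchor_num * (5 + class_num)`; all concrete numbers below are illustrative:

```python
import paddle
import numpy as np

class_num = 2
anchors = [10, 13, 16, 30]                 # two anchors
x = paddle.rand([1, len(anchors) // 2 * (5 + class_num), 13, 13])
img_size = paddle.to_tensor(np.array([[608, 608]], dtype='int32'))
boxes, scores = paddle.vision.ops.yolo_box(
    x, img_size, anchors=anchors, class_num=class_num,
    conf_thresh=0.01, downsample_ratio=32)
print(boxes.shape, scores.shape)           # [1, 338, 4] [1, 338, 2]
```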