Unverified commit c74aaf67, authored by zhouweiwei2014, committed by GitHub

[Zero-Dim] Support 0D for numel/rank/size/optimizer/create_parameter/create_global_var, fix some usage to adapt 0D (#51566)

Parent commit: c658be9b
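For context, a minimal sketch of the user-visible behavior this commit targets (assuming a Paddle build that includes this change): paddle.numel, paddle.rank, and the size op now produce 0D tensors (shape []) instead of shape [1], and paddle.create_parameter / create_global_var accept shape=[], as exercised by the tests below:

import paddle

x = paddle.rand([3, 5])
print(paddle.numel(x).shape)   # [] -- 0D tensor holding 15
print(paddle.rank(x).shape)    # [] -- 0D tensor holding 2

# 0D parameter, matching test_create_parameter_var in this commit
p = paddle.create_parameter(shape=[], dtype='float32')
print(p.shape)                 # []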
@@ -343,8 +343,9 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
 inline bool recompute_reduce_all(const DenseTensor& x,
                                  const IntArray& dims,
                                  bool reduce_all = false) {
-  if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size() ||
-      reduce_all) {
+  if (dims.size() == 0 || x.dims().size() == 0 ||
+      static_cast<int>(dims.size()) == x.dims().size() || reduce_all) {
+    // when input 0D, it can only reduce_all
     return true;
   } else {
     return false;
......
@@ -2963,24 +2963,32 @@ DDim ReduceInferDim(const MetaTensor& x,
   std::vector<int64_t> formated_axis = axis;
   for (size_t i = 0; i < axis.size(); ++i) {
-    PADDLE_ENFORCE_LT(axis[i],
-                      x_rank,
-                      errors::InvalidArgument(
-                          "The reduce dim index %d should be in the "
-                          "range [ -dimension(X), dimension(X) ) "
-                          "which dimesion = %d. But received dim index = %d.",
-                          i,
-                          x_rank,
-                          axis[i]));
-    PADDLE_ENFORCE_GE(axis[i],
-                      -x_rank,
-                      errors::InvalidArgument(
-                          "The reduce dim index %d should be in the "
-                          "range [ -dimension(X), dimension(X) ) "
-                          "which dimesion = %d. But received dim index = %d.",
-                          i,
-                          x_rank,
-                          axis[i]));
+    if (x_rank == 0) {
+      PADDLE_ENFORCE_EQ(
+          axis[i] == 0 || axis[i] == -1,
+          true,
+          phi::errors::InvalidArgument(
+              "When input 0D Tensor, the axis can only be -1, 0, None or []"));
+    } else {
+      PADDLE_ENFORCE_LT(axis[i],
+                        x_rank,
+                        errors::InvalidArgument(
+                            "The reduce dim index %d should be in the "
+                            "range [ -dimension(X), dimension(X) ) "
+                            "which dimesion = %d. But received dim index = %d.",
+                            i,
+                            x_rank,
+                            axis[i]));
+      PADDLE_ENFORCE_GE(axis[i],
+                        -x_rank,
+                        errors::InvalidArgument(
+                            "The reduce dim index %d should be in the "
+                            "range [ -dimension(X), dimension(X) ) "
+                            "which dimesion = %d. But received dim index = %d.",
+                            i,
+                            x_rank,
+                            axis[i]));
+    }
 
     if (axis[i] < 0) {
       formated_axis[i] = axis[i] + x_rank;
@@ -3369,12 +3377,7 @@ void ShardIndexInferMeta(const MetaTensor& in,
 
 void NumelInferMeta(const MetaTensor& input, MetaTensor* out) {
   out->set_dtype(DataType::INT64);
-  if (input.dims().size() == 0) {
-    out->set_dims(phi::make_ddim({}));
-  } else {
-    // TODO(zhouwei): will change shape [1] to [] to support zero-dim
-    out->set_dims(phi::make_ddim({1}));
-  }
+  out->set_dims(phi::make_ddim({}));
 }
 
 void SliceRawInferMeta(const MetaTensor& input,
......
@@ -22,53 +22,6 @@
 #include "paddle/phi/kernels/impl/reduce_grad.h"
 
 namespace phi {
 
-template <typename T, typename Context>
-void ComputeFromInput(const Context& dev_ctx,
-                      const DenseTensor& x,
-                      const DenseTensor& input2,
-                      const std::vector<int64_t>& dims,
-                      DenseTensor* x_grad) {
-  auto* input0 = &x;
-  auto* output = x_grad;
-  dev_ctx.template Alloc<T>(output);
-
-  const auto* input2_d = input2.data<T>();
-  auto* output_d = output->data<T>();
-
-  // handle reduce_all
-  if (input2.dims().size() == 1 && input2.dims()[0] == 1) {
-    for (int64_t i = 0; i < phi::product(input0->dims()); ++i) {
-      output_d[i] = input2_d[0];
-    }
-    return;
-  }
-
-  // handle reduce by one dimension
-  int reduce_dim_index = dims[0];
-  if (reduce_dim_index < 0) {
-    reduce_dim_index += input0->dims().size();
-  }
-
-  auto& input_dim = input0->dims();
-  int64_t before_dim = 1;
-  for (int i = 0; i < reduce_dim_index; ++i) {
-    before_dim *= input_dim[i];
-  }
-  int64_t reduce_dim = input_dim[reduce_dim_index];
-  int64_t after_dim = 1;
-  for (int i = reduce_dim_index + 1; i < input_dim.size(); ++i) {
-    after_dim *= input_dim[i];
-  }
-
-  for (int64_t i = 0; i < before_dim; ++i) {
-    for (int64_t j = 0; j < reduce_dim; ++j) {
-      for (int64_t k = 0; k < after_dim; ++k) {
-        output_d[i * reduce_dim * after_dim + j * after_dim + k] =
-            input2_d[i * after_dim + k];
-      }
-    }
-  }
-}
-
 template <typename T, typename Context>
 void ReduceSumGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
@@ -78,24 +31,6 @@ void ReduceSumGradKernel(const Context& dev_ctx,
                          bool reduce_all,
                          DenseTensor* x_grad) {
   reduce_all = recompute_reduce_all(x, dims, reduce_all);
-
-  if (dims.size() == 1) {
-    if (out_grad.dtype() != x.dtype()) {
-      DenseTensorMeta x_grad_meta(
-          out_grad.dtype(), x_grad->dims(), x_grad->layout());
-      DenseTensor x_grad_tmp =
-          phi::Empty<Context>(dev_ctx, std::move(x_grad_meta));
-
-      ComputeFromInput<T, Context>(
-          dev_ctx, x, out_grad, dims.GetData(), &x_grad_tmp);
-
-      phi::CastKernel<T>(dev_ctx, x_grad_tmp, x.dtype(), x_grad);
-    } else {
-      ComputeFromInput<T, Context>(
-          dev_ctx, x, out_grad, dims.GetData(), x_grad);
-    }
-  }
-
   ReduceGradKernel<Context, T, funcs::SumGradFunctor, true>(dev_ctx,
                                                             x,
                                                             paddle::none,
......
@@ -14,6 +14,7 @@
 import math
 import warnings
 
+import numpy as np
 import paddle
 
 from .. import unique_name
@@ -951,10 +952,9 @@ class ReduceLROnPlateau(LearningRateDecay):
         # loss must be 1-D Tensor with shape [1]
         check_type(loss, 'loss', Variable, 'ReduceLROnPlateau.step')
-        assert len(loss.shape) == 1 and loss.shape[0] == 1, (
-            "the loss.shape "
-            "should be (1L,), but the current loss.shape is {}. Maybe that "
-            "you should call paddle.mean to process it first.".format(
+        assert np.prod(loss.shape) == 1, (
+            "The number of elements of loss should be 1, but the current loss.shape is {}, whose number of elements is not 1. "
+            "Maybe that you should call paddle.mean to process it first.".format(
                 loss.shape
             )
         )
......
@@ -728,7 +728,8 @@ def monkey_patch_varbase():
         return framework.default_main_program().global_block()
 
     def __nonzero__(self):
-        numel = np.prod(self.shape)
+        # np.prod([]) -> np.float64, so use int
+        numel = int(np.prod(self.shape))
         assert (
             numel == 1
         ), "When Variable is used as the condition of if/while , Variable can only contain one element."
......
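The new comment in __nonzero__ reflects standard NumPy behavior: for a 0D tensor, self.shape is [], and np.prod([]) returns a numpy.float64 equal to 1.0 rather than a Python int, hence the explicit int(...) wrapper. A quick illustration of that NumPy behavior (not code from this commit):

import numpy as np

print(np.prod([]))        # 1.0
print(type(np.prod([])))  # <class 'numpy.float64'>
print(int(np.prod([])))   # 1, a plain Python int suitable for `numel == 1`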
@@ -1115,8 +1115,8 @@ class Optimizer:
         else:
             assert isinstance(callbacks, list)
         program = loss.block.program
-        assert len(loss.shape) == 1 and loss.shape[0] == 1, (
-            "The loss.shape should be (1L,), but the current loss.shape is {}. "
+        assert np.prod(loss.shape) == 1, (
+            "The number of elements of loss should be 1, but the current loss.shape is {}, whose number of elements is not 1. "
             "Maybe that you should call paddle.mean to process the current loss.".format(
                 loss.shape
             )
......
@@ -115,7 +115,7 @@ def train(use_cuda, save_dirname, is_local, use_bf16, pure_bf16):
             )
             if avg_loss_value.dtype == numpy.uint16:
                 avg_loss_value = convert_uint16_to_float(avg_loss_value)
-            if avg_loss_value[0] < 10.0:
+            if float(avg_loss_value) < 10.0:
                 if save_dirname is not None:
                     paddle.static.save_inference_model(
                         save_dirname,
......
@@ -263,7 +263,7 @@ def train(use_cuda, save_dirname, is_local=True):
                 )
                 return
 
-            if math.isnan(float(out[0])):
+            if math.isnan(float(out)):
                 sys.exit("got NaN loss, training failed.")
     if is_local:
......
@@ -52,7 +52,7 @@ class SimpleNet(paddle.nn.Layer):
     def forward(self, x):
         is_use = (
-            paddle.equal_all(x, paddle.ones(shape=(batch, in_dim))).numpy()[0]
+            paddle.equal_all(x, paddle.ones(shape=(batch, in_dim))).item()
             and self.trainer_id == 1
         )
......
@@ -105,18 +105,12 @@ for _type_name in {'float32', 'float64', 'int32', 'int64', 'bool'}:
 class TestEqualReduceAPI(unittest.TestCase):
-    def test_name(self):
-        x = paddle.assign(np.array([3, 4], dtype="int32"))
-        y = paddle.assign(np.array([3, 4], dtype="int32"))
-        out = paddle.equal_all(x, y, name='equal_res')
-        assert 'equal_res' in out.name
-
     def test_dynamic_api(self):
         paddle.disable_static()
         x = paddle.ones(shape=[10, 10], dtype="int32")
         y = paddle.ones(shape=[10, 10], dtype="int32")
         out = paddle.equal_all(x, y)
-        assert out.numpy()[0] is np.True_
+        assert out.item() is True
         paddle.enable_static()
......
@@ -48,7 +48,7 @@ class TestCompiledProgram(unittest.TestCase):
             feed={"image": self.img, "label": self.label},
             fetch_list=[loss.name],
         )
-        self.loss = loss_data[0]
+        self.loss = float(loss_data)
 
     def test_compiled_program_base(self):
         with new_program_scope():
@@ -70,7 +70,7 @@ class TestCompiledProgram(unittest.TestCase):
             feed={"image": self.img, "label": self.label},
             fetch_list=[loss.name],
         )
-        np.testing.assert_array_equal(loss_data[0], self.loss)
+        np.testing.assert_array_equal(float(loss_data), self.loss)
 
 
 class TestCompiledProgramError(unittest.TestCase):
......
@@ -553,7 +553,7 @@ class DyGraphTrainModel:
         self.clear_gradients()
 
-        return g_loss.numpy()[0], d_loss.numpy()[0]
+        return float(g_loss), float(d_loss)
 
 
 class StaticGraphTrainModel:
......
@@ -30,7 +30,6 @@ class TestNumelOp(OpTest):
         self.inputs = {
             'Input': x,
         }
-        # TODO(zhouwei): will change shape [1] to [] to support zero-dim
         self.outputs = {'Out': np.array([np.size(x)])}
 
     def test_check_output(self):
@@ -73,10 +72,10 @@ class TestNumelAPI(unittest.TestCase):
         )
         # TODO(zhouwei): will change shape [1] to [] to support zero-dim
         assert np.array_equal(
-            res_1, np.array([np.size(input_1)]).astype("int64")
+            res_1, np.array(np.size(input_1)).astype("int64")
         )
         assert np.array_equal(
-            res_2, np.array([np.size(input_2)]).astype("int64")
+            res_2, np.array(np.size(input_2)).astype("int64")
         )
 
     def test_numel_imperative(self):
......
@@ -33,7 +33,7 @@ class TestSizeOp(OpTest):
         self.config()
         input = np.zeros(self.shape, dtype='bool')
         self.inputs = {'Input': input}
-        self.outputs = {'Out': np.array([np.size(input)], dtype='int64')}
+        self.outputs = {'Out': np.array(np.size(input), dtype='int64')}
 
     def config(self):
         pass
@@ -85,10 +85,10 @@ class TestSizeAPI(unittest.TestCase):
         )
         # TODO(zhouwei): will change shape [1] to [] to support zero-dim
         assert np.array_equal(
-            res_1, np.array([np.size(input_1)]).astype("int64")
+            res_1, np.array(np.size(input_1)).astype("int64")
        )
         assert np.array_equal(
-            res_2, np.array([np.size(input_2)]).astype("int64")
+            res_2, np.array(np.size(input_2)).astype("int64")
         )
 
     def test_size_imperative(self):
......
@@ -204,6 +204,20 @@ class TestReduceAPI(unittest.TestCase):
                 np.testing.assert_allclose(x.grad.numpy(), np.array(1.0))
                 np.testing.assert_allclose(out.grad.numpy(), np.array(1.0))
 
+            out1 = api(x, 0)
+            self.assertEqual(out1.shape, [])
+            self.assertEqual(out1, out)
+            out1.backward()
+
+            out2 = api(x, -1)
+            self.assertEqual(out2.shape, [])
+            self.assertEqual(out2, out)
+            out2.backward()
+
+            if x.grad is not None:
+                self.assertEqual(x.grad.shape, [])
+                np.testing.assert_allclose(x.grad.numpy(), np.array(3.0))
+
         paddle.enable_static()
 
     def test_static_reduce(self):
@@ -227,6 +241,12 @@ class TestReduceAPI(unittest.TestCase):
                 out_empty_list = api(x, None)
                 self.assertEqual(out_empty_list.shape, ())
 
+                out1 = api(x, 0)
+                self.assertEqual(out1.shape, ())
+
+                out2 = api(x, -1)
+                self.assertEqual(out2.shape, ())
+
                 fetch_list = [x, out]
                 if block.has_var(x.grad_name):
                     fetch_list.extend([x.grad_name, out.grad_name])
@@ -550,6 +570,16 @@ class TestSundryAPI(unittest.TestCase):
         paddle.disable_static()
         self.x = paddle.rand([])
 
+    def test_create_parameter_var(self):
+        zero_dim_param = paddle.create_parameter(shape=[], dtype='float32')
+        self.assertEqual(zero_dim_param.shape, [])
+
+        zero_dim_var = paddle.tensor.creation.create_global_var(
+            shape=[], value=0.5, dtype='float32'
+        )
+        self.assertEqual(zero_dim_var.shape, [])
+        self.assertEqual(zero_dim_var.item(), 0.5)
+
     def test_expand(self):
         # case1
         x = paddle.full([], 1, 'float32')
@@ -963,15 +993,29 @@ class TestSundryAPI(unittest.TestCase):
         np.testing.assert_array_equal(x_np, np.array(0.5))
 
     def test_numel(self):
+        # 1) x is 0D
         out = paddle.numel(self.x)
         self.assertEqual(out.shape, [])
         np.testing.assert_array_equal(out.numpy(), np.array(1))
 
+        # 2) x is ND
+        x = paddle.full([3, 5], 0.5)
+        out = paddle.numel(x)
+        self.assertEqual(out.shape, [])
+        np.testing.assert_array_equal(out.numpy(), np.array(15))
+
     def test_rank(self):
+        # 1) x is 0D
        out = paddle.rank(self.x)
         self.assertEqual(out.shape, [])
         np.testing.assert_array_equal(out.numpy(), np.array(0))
 
+        # 1) x is ND
+        x = paddle.full([3, 5], 0.5)
+        out = paddle.rank(x)
+        self.assertEqual(out.shape, [])
+        np.testing.assert_array_equal(out.numpy(), np.array(2))
+
     def test_shape(self):
         out = paddle.shape(self.x)
         self.assertEqual(out.shape, [0])
@@ -1898,6 +1942,23 @@ class TestSundryAPIStatic(unittest.TestCase):
         paddle.enable_static()
         self.exe = paddle.static.Executor()
 
+    @prog_scope()
+    def test_create_parameter_var(self):
+        zero_dim_param = paddle.create_parameter(shape=[], dtype='float32')
+        self.assertEqual(zero_dim_param.shape, ())
+        prog = paddle.static.default_startup_program()
+        res = self.exe.run(prog, fetch_list=[zero_dim_param])
+        self.assertEqual(res[0].shape, ())
+
+        zero_dim_var = paddle.static.create_global_var(
+            shape=[], value=0.5, dtype='float32'
+        )
+        self.assertEqual(zero_dim_var.shape, ())
+        prog = paddle.static.default_startup_program()
+        res = self.exe.run(prog, fetch_list=[zero_dim_var])
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[0], 0.5)
+
     @prog_scope()
     def test_expand(self):
         x = paddle.full([], 1, 'float32')
@@ -3142,6 +3203,7 @@ class TestSundryAPIStatic(unittest.TestCase):
         res = self.exe.run(prog, feed={"x": x_tensor}, fetch_list=[out])
         self.assertEqual(res[0].shape, (3, 4, 2))
 
+    @prog_scope()
     def test_prelu(self):
         x1 = paddle.full([], 1.0, 'float32')
         x1.stop_gradient = False
@@ -3174,6 +3236,7 @@ class TestSundryAPIStatic(unittest.TestCase):
         self.assertEqual(res[4].shape, ())
         self.assertEqual(res[5].shape, ())
 
+    @prog_scope()
     def test_static_nn_prelu(self):
         x1 = paddle.full([], 1.0, 'float32')
         x1.stop_gradient = False
@@ -3234,6 +3297,57 @@ class TestSundryAPIStatic(unittest.TestCase):
         self.assertEqual(res[3].shape, ())
         np.testing.assert_allclose(res[3], np.array(1.0))
 
+    @prog_scope()
+    def test_numel(self):
+        # 1) x is 0D
+        x = paddle.full([], 0.5)
+        out = paddle.numel(x)
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out])
+        self.assertEqual(res[0].shape, ())
+        np.testing.assert_array_equal(res[0], np.array(1))
+
+        # 2) x is ND
+        x = paddle.full([3, 5], 0.5)
+        out = paddle.numel(x)
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out])
+        self.assertEqual(res[0].shape, ())
+        np.testing.assert_array_equal(res[0], np.array(15))
+
+    @prog_scope()
+    def test_rank(self):
+        # 1) x is 0D
+        x = paddle.full([], 0.5)
+        out = paddle.rank(x)
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out])
+        self.assertEqual(res[0].shape, ())
+        np.testing.assert_array_equal(res[0], np.array(0))
+
+        # 1) x is ND
+        x = paddle.full([3, 5], 0.5)
+        out = paddle.rank(x)
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out])
+        self.assertEqual(res[0].shape, ())
+        np.testing.assert_array_equal(res[0], np.array(2))
+
+    @prog_scope()
+    def _test_shape(self):
+        x = paddle.full([], 0.5)
+        out = paddle.shape(x)
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out])
+        # 0-Size should be [ np.array([]) ], its [None] now
+        self.assertEqual(res[0].shape, (0))
+        np.testing.assert_array_equal(res[0], np.array([]))
+
 
 # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
 class TestNoBackwardAPI(unittest.TestCase):
......
@@ -605,10 +605,7 @@ def batch_norm_orig2prim(
 @REGISTER_ORIG2PRIM('size')
 def size_orig2prim(op, x):
-    # TODO(zhouwei): will change shape [1] to [] to support zero-dim
-    return fill_const(
-        functools.reduce(operator.mul, x.shape), (1,), paddle.int64
-    )
+    return fill_const(functools.reduce(operator.mul, x.shape), (), paddle.int64)
 
 
 # Register prim2orig lower rules
......
@@ -1108,8 +1108,8 @@ class Optimizer:
         else:
             assert isinstance(callbacks, list)
         program = loss.block.program
-        assert len(loss.shape) == 1 and loss.shape[0] == 1, (
-            "The loss.shape should be (1L,), but the current loss.shape is {}. "
+        assert np.prod(loss.shape) == 1, (
+            "The number of elements of loss should be 1, but the current loss.shape is {}, whose number of elements is not 1. "
             "Maybe that you should call paddle.mean to process the current loss.".format(
                 loss.shape
             )
......
@@ -89,7 +89,7 @@ def train_lenet(lenet, reader, optimizer):
             lenet.clear_gradients()
 
             if batch_id % 100 == 0:
-                loss_list.append(avg_loss.numpy()[0])
+                loss_list.append(float(avg_loss))
                 _logger.info('{}: {}'.format('loss', avg_loss.numpy()))
 
     return loss_list
......