Unverified commit 9f9cd919, authored by zhouweiwei2014, committed by GitHub

[Zero-Dim] Support paddle.max output 0D, test=allcase (#53242)

Parent ddd72039
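The user-visible change, in short: a full reduce with paddle.max now yields a 0-D tensor (shape []) instead of a 1-element tensor, matching the other reduce ops. A minimal sketch of the intended behavior (illustrative only, not code from this commit):

```python
import paddle

x = paddle.rand([3, 5])
x.stop_gradient = False

out = paddle.max(x)  # reduce over all axes
print(out.shape)     # [] -- 0-D tensor after this change (previously a 1-D tensor)

out.backward()
print(x.grad.shape)  # [3, 5] -- the gradient keeps the input shape
```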
@@ -54,9 +54,10 @@ class ReduceMaxCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
 }  // namespace operators
 }  // namespace paddle

-DECLARE_INFER_SHAPE_FUNCTOR(reduce_max,
-                            ReduceMaxInferShapeFunctor,
-                            PD_INFER_META(phi::OriginReduceInferMetaBase));
+DECLARE_INFER_SHAPE_FUNCTOR(
+    reduce_max,
+    ReduceMaxInferShapeFunctor,
+    PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase));

 REGISTER_OPERATOR(
     reduce_max,
...
@@ -1335,7 +1335,7 @@ void max_grad(const Tensor& x,
   } else {
     auto axis_ = std::vector<int64_t>();
     if (reduce_all) {
-      for (int64_t i = 1; i < x_dim_size; i++) {
+      for (int64_t i = 0; i < x_dim_size; i++) {
         axis_.push_back(i);
       }
     } else {
...
@@ -744,7 +744,7 @@
   args : (Tensor x, IntArray axis={}, bool keepdim=false)
   output : Tensor(out)
   infer_meta :
-    func : OriginReduceInferMeta
+    func : ReduceIntArrayAxisInferMeta
   kernel :
     func : max
   backward : max_grad
...
@@ -89,6 +89,7 @@ PD_REGISTER_KERNEL(add_n,
                    double,
                    int,
                    phi::dtype::bfloat16,
+                   phi::dtype::float16,
                    int64_t) {}

 PD_REGISTER_KERNEL(add_n_array,
@@ -99,4 +100,5 @@ PD_REGISTER_KERNEL(add_n_array,
                    double,
                    int,
                    phi::dtype::bfloat16,
+                   phi::dtype::float16,
                    int64_t) {}
@@ -395,6 +395,7 @@ template struct SelectedRowsAddToTensor<phi::CPUContext, float>;
 template struct SelectedRowsAddToTensor<phi::CPUContext, double>;
 template struct SelectedRowsAddToTensor<phi::CPUContext, int>;
 template struct SelectedRowsAddToTensor<phi::CPUContext, int64_t>;
+template struct SelectedRowsAddToTensor<phi::CPUContext, phi::dtype::float16>;
 template struct SelectedRowsAddToTensor<phi::CPUContext, phi::dtype::bfloat16>;

 #ifdef PADDLE_WITH_XPU
...
@@ -105,19 +105,19 @@ inline DDim GetOutputSqueezeShape(const std::vector<int> squeeze_dims,
 inline DDim GetUnsqueezeShape(const std::vector<int64_t> unsqz_dims,
                               const DDim& in_dims) {
-  int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
-  int cur_output_size = in_dims.size();
-  std::vector<int64_t> output_shape(output_size, 0);
+  int output_rank = in_dims.size() + static_cast<int>(unsqz_dims.size());
+  int cur_output_rank = in_dims.size();
+  std::vector<int64_t> output_shape(output_rank, 0);

   // Validity Check: rank range.
   PADDLE_ENFORCE_LE(
-      output_size,
+      output_rank,
       6,
       phi::errors::InvalidArgument("The output "
                                    "tensor's rank should be less than 6."));

   for (int axis : unsqz_dims) {
-    int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
+    int cur = axis < 0 ? axis + cur_output_rank + 1 : axis;
     // Vaildity Check: the axis bound
     PADDLE_ENFORCE_GE(
         cur,
@@ -125,12 +125,12 @@ inline DDim GetUnsqueezeShape(const std::vector<int64_t> unsqz_dims,
         phi::errors::InvalidArgument("The insert dimension value should "
                                      "not be less than 0"));
     PADDLE_ENFORCE_LE(cur,
-                      cur_output_size,
+                      cur_output_rank,
                       phi::errors::InvalidArgument(
                           "The insert dimension value shoule not be larger "
                           "than the dimension size of input tensor"));
     // Move old axis, and insert new axis
-    for (int i = cur_output_size; i >= cur; --i) {
+    for (int i = cur_output_rank; i >= cur; --i) {
       if (output_shape[i] == 1) {
         // Move axis
         output_shape[i + 1] = 1;
@@ -139,11 +139,11 @@ inline DDim GetUnsqueezeShape(const std::vector<int64_t> unsqz_dims,
       }
     }
     output_shape[cur] = 1;
     // Add the output size.
-    cur_output_size++;
+    cur_output_rank++;
   }

   // Make output shape
-  for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
+  for (int in_idx = 0, out_idx = 0; out_idx < output_rank; ++out_idx) {
     if (output_shape[out_idx] == 0) {
       output_shape[out_idx] = in_dims[in_idx++];
     }
...
@@ -102,8 +102,10 @@ void ReduceKernel(const Context& dev_ctx,
     reduction_p->execute(astream, reduction_args);
     astream.wait();

-    out->set_mem_desc(
-        dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
+    const auto reshape_dims = out->dims().size() != 0
+                                  ? vectorize<int64_t>(out->dims())
+                                  : std::vector<int64_t>{1};
+    out->set_mem_desc(dst_memory_p->get_desc().reshape(reshape_dims));
   }
 }
...
@@ -242,7 +242,7 @@ def unscale_method(self, optimizer):
     paddle.distributed.all_reduce(
         is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
     )
-    self._found_inf = is_found_inf.numpy()[0]
+    self._found_inf = int(is_found_inf)

 class MixPrecisionScaler:
...
@@ -179,7 +179,7 @@ def monkey_patch_math_varbase():
     @property
     def _size_(var):
-        return np.prod(var.shape)
+        return int(np.prod(var.shape))

     @property
     def _T_(var):
...
@@ -212,7 +212,7 @@ class TestLearningRateDecayDygraph(unittest.TestCase):
         adam_test.set_dict(opt_state)
         self.assertEqual(
             adam_test._learning_rate.best_loss,
-            adam3._learning_rate.best_loss.numpy()[0],
+            adam3._learning_rate.best_loss,
             "best_loss is different before and after set_dict",
         )
         self.assertEqual(
@@ -275,7 +275,7 @@ class TestLearningRateDecayDygraph(unittest.TestCase):
             t = lr()
             np.testing.assert_allclose(
-                t.numpy()[0].item(), right_result[i], rtol=1e-05
+                t.numpy().item(), right_result[i], rtol=1e-05
             )

         with self.assertRaises(TypeError):
@@ -342,7 +342,7 @@ class TestLearningRateDecayDygraph(unittest.TestCase):
                 right_result = step_decay(
                     epoch, learning_rate, step_size, decay_rate
                 )
-                fluid_result = scheduler().numpy()[0]
+                fluid_result = scheduler().numpy().item()
                 scheduler.epoch()
                 self.assertAlmostEqual(
                     right_result,
@@ -371,7 +371,7 @@ class TestLearningRateDecayDygraph(unittest.TestCase):
             for epoch in range(30):
                 right_result = lambda_decay(epoch, learning_rate, lr_lambda)
-                fluid_result = scheduler().numpy()[0]
+                fluid_result = scheduler().numpy().item()
                 scheduler.epoch()
                 self.assertAlmostEqual(
                     right_result,
...
@@ -208,7 +208,7 @@ class TestReduceOnPlateauDecay:
             self.assertEqual(
                 scheduler.cooldown_counter, scheduler1.cooldown_counter
             )
-            self.assertEqual(scheduler.best.numpy()[0], scheduler1.best)
+            self.assertEqual(scheduler.best, scheduler1.best)
             self.assertEqual(scheduler.num_bad_epochs, scheduler1.num_bad_epochs)
             self.assertEqual(scheduler.last_epoch, scheduler1.last_epoch)
             self.assertEqual(scheduler.last_lr, scheduler1.last_lr)
...
@@ -219,17 +219,19 @@ class TestReduceAPI(unittest.TestCase):
             self.assertEqual(x.grad.shape, [])
             np.testing.assert_allclose(x.grad.numpy(), np.array(3.0))

-            # 2) x is ND
             if api in [
                 paddle.sum,
                 paddle.mean,
                 paddle.nanmean,
                 paddle.nansum,
+                paddle.max,
             ]:
                 return

-            x = paddle.rand([3, 5])
+            # 2) x is ND, reduce to 0D
+            if api in [paddle.all, paddle.any]:
+                x = paddle.randint(0, 2, [3, 5]).astype('bool')
+            else:
+                x = paddle.rand([3, 5])
             x.stop_gradient = False
             out = api(x, None)
             out.retain_grads()
@@ -240,6 +242,21 @@
             self.assertEqual(out.grad.shape, [])
             self.assertEqual(x.grad.shape, [3, 5])

+            # 3) x is 1D, axis=0, reduce to 0D
+            if api in [paddle.all, paddle.any]:
+                x = paddle.randint(0, 2, [5]).astype('bool')
+            else:
+                x = paddle.rand([5])
+            x.stop_gradient = False
+            out = api(x, 0)
+            out.retain_grads()
+            out.backward()
+
+            self.assertEqual(out.shape, [])
+            if x.grad is not None:
+                self.assertEqual(out.grad.shape, [])
+                self.assertEqual(x.grad.shape, [5])
+
         paddle.enable_static()

     def test_static_reduce(self):
@@ -284,16 +301,19 @@
             np.testing.assert_allclose(res[2], np.array(1.0))
             np.testing.assert_allclose(res[3], np.array(1.0))

-            # 2) x is ND
             if api in [
                 paddle.sum,
                 paddle.mean,
                 paddle.nanmean,
                 paddle.nansum,
+                paddle.max,
             ]:
                 return

+            # 2) x is ND, reduce to 0D
+            if api in [paddle.all, paddle.any]:
+                x = paddle.randint(0, 2, [3, 5]).astype('bool')
+            else:
+                x = paddle.rand([3, 5])
             x = paddle.rand([3, 5])
             x.stop_gradient = False
             out = api(x, None)
@@ -309,6 +329,25 @@
             self.assertEqual(res[1].shape, ())
             self.assertEqual(res[2].shape, (3, 5))

+            # 3) x is 1D, axis=0, reduce to 0D
+            if api in [paddle.all, paddle.any]:
+                x = paddle.randint(0, 2, [5]).astype('bool')
+            else:
+                x = paddle.rand([5])
+            x.stop_gradient = False
+            out = api(x, 0)
+            paddle.static.append_backward(out)
+
+            fetch_list = [out]
+            if block.has_var(x.grad_name):
+                fetch_list.extend([out.grad_name, x.grad_name])
+
+            res = exe.run(main_prog, fetch_list=fetch_list)
+            self.assertEqual(res[0].shape, ())
+            if len(res) > 1:
+                self.assertEqual(res[1].shape, ())
+                self.assertEqual(res[2].shape, (5,))
+
         paddle.disable_static()
...
@@ -81,8 +81,13 @@ class ProgressBar:
         for i, (k, val) in enumerate(values):
             if k == "loss":
-                val = val if isinstance(val, (list, np.ndarray)) else [val]
-                if isinstance(val[0], np.uint16):
+                if isinstance(val, list):
+                    scalar_val = val[0]
+                elif isinstance(val, np.ndarray):
+                    scalar_val = val.item()
+                else:
+                    scalar_val = val
+                if isinstance(scalar_val, np.uint16):
                     values[i] = ("loss", list(convert_uint16_to_float(val)))

         if current_num:
...
@@ -700,7 +700,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
         global_norm_var = paddle.add_n(global_norm_var)
         global_norm_var = paddle.sqrt(global_norm_var)
         max_global_norm = paddle.full(
-            shape=[1], dtype=global_norm_var.dtype, fill_value=self.clip_norm
+            shape=[], dtype=global_norm_var.dtype, fill_value=self.clip_norm
         )

         need_clip = False
...
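For context on the shape=[1] → shape=[] changes in this hunk and the next: paddle.full with shape=[] produces a 0-D scalar tensor, which composes naturally with reduce outputs that are now 0-D. An illustrative sketch under that assumption (not code from this commit):

```python
import paddle

scalar = paddle.full(shape=[], dtype='float32', fill_value=1.0)   # 0-D
vector = paddle.full(shape=[1], dtype='float32', fill_value=1.0)  # 1-D
print(scalar.shape, vector.shape)  # [] vs [1]

# A 0-D threshold compares cleanly against a 0-D global norm.
norm = paddle.sqrt(paddle.sum(paddle.rand([4]) ** 2))  # 0-D
print((norm > scalar).shape)  # [] -- still 0-D
```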
@@ -178,7 +178,7 @@ class FakeQuantActLSQPlus(Layer):
         s_attr = ParamAttr(
             name=self._scale_name, initializer=Constant(1.0), trainable=True
         )
-        self.s = self.create_parameter(shape=[1], attr=s_attr, dtype='float32')
+        self.s = self.create_parameter(shape=[], attr=s_attr, dtype='float32')
         self.s.stop_gradient = False

         if not self.symmetric:
@@ -189,7 +189,7 @@ class FakeQuantActLSQPlus(Layer):
             name=self._beta_name, initializer=Constant(0.0), trainable=True
         )
         self.beta = self.create_parameter(
-            shape=[1], attr=beta_attr, dtype='float32'
+            shape=[], attr=beta_attr, dtype='float32'
         )
         self.beta.stop_gradient = False
...
@@ -26,10 +26,7 @@ from paddle.incubate.autograd.utils import as_tensors
 # Finite Difference Utils
 ##########################################################
 def _product(t):
-    if isinstance(t, int):
-        return t
-    else:
-        return np.product(t)
+    return int(np.product(t))


 def _get_item(t, idx):
...
@@ -407,7 +407,7 @@ class BaseModel(paddle.nn.Layer):
         parent_ids = []

         for step_idx in range(paddle.to_tensor(self.beam_max_step_num)):
-            if paddle.sum(1 - beam_finished).numpy()[0] == 0:
+            if paddle.sum(1 - beam_finished) == 0:
                 break
             step_input = self._merge_batch_beams(step_input)
             new_dec_hidden, new_dec_cell = [], []
...
@@ -28,7 +28,7 @@ from paddle.static import InputSpec
 def for_in_range(x):
     z = paddle.tensor.fill_constant([1], 'int32', 0)
     x = fluid.dygraph.to_variable(x)
-    for i in range(x.numpy()[0]):
+    for i in range(x.numpy().item()):
         z = z + i
     return z
...
@@ -342,7 +342,7 @@ def train(args, to_static):
             model.train()
             avg_cost, prediction, acc = model(doc, label)

-            loss_data.append(avg_cost.numpy()[0])
+            loss_data.append(float(avg_cost))
             avg_cost.backward()
             sgd_optimizer.minimize(avg_cost)
@@ -358,7 +358,7 @@
                     "step: %d, ave loss: %f, speed: %f steps/s"
                     % (
                         batch_id,
-                        avg_cost.numpy()[0],
+                        float(avg_cost),
                         args.log_step / used_time,
                     )
                 )
...