diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h index c4efc4ab72d6308d89220ef9e49ad2240b2ddef2..acda31e0f2309bed59b1d31b39abcacf21d72b4e 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h @@ -144,7 +144,20 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel { // skip out auto *out = dout; - if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { + // Special case when dy is not needed and dx doesn't reduce + if (dx != nullptr && dy == nullptr && dx->dims() == dout->dims()) { + VLOG(4) << "Special case when dy is not needed and dx doesn't " + "reduce"; + framework::TensorCopy( + *dout, ctx.GetPlace(), + ctx.template device_context(), dx); + } else if (dx == nullptr && dy != nullptr && dy->dims() == dout->dims()) { + VLOG(4) << "Special case when dx is not needed and dy doesn't " + "reduce"; + framework::TensorCopy( + *dout, ctx.GetPlace(), + ctx.template device_context(), dy); + } else if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { elementwise_add_grad(ctx, x, y, out, dout, dx, dy); } else { default_elementwise_add_grad(ctx, x, y, out, dout, dx, diff --git a/paddle/fluid/operators/gelu_op.h b/paddle/fluid/operators/gelu_op.h index 329b8583192a41c6c088cdbbb3ee7bd68c77f373..936da8dee85fcf585e72c48565d057ea31204d14 100644 --- a/paddle/fluid/operators/gelu_op.h +++ b/paddle/fluid/operators/gelu_op.h @@ -36,10 +36,22 @@ struct GeluFunctor { void operator()(Device d, X x, Out out, bool approximate) const { if (approximate) { // gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3}))) - auto temp = (static_cast(M_2_SQRTPI * M_SQRT1_2) * - (x + static_cast(0.044715) * x.cube())) - .tanh(); - out.device(d) = x * static_cast(0.5) * (static_cast(1) + temp); + if (std::is_same::value) { + VLOG(4) << "cast from float16 to float before computing"; + auto casted_x = x.template cast(); + auto temp = + (static_cast(M_2_SQRTPI * M_SQRT1_2) * + (casted_x + static_cast(0.044715) * casted_x.cube())) + .tanh(); + out.device(d) = (casted_x * static_cast(0.5) * + (static_cast(1) + temp)) + .template cast(); + } else { + auto temp = (static_cast(M_2_SQRTPI * M_SQRT1_2) * + (x + static_cast(0.044715) * x.cube())) + .tanh(); + out.device(d) = x * static_cast(0.5) * (static_cast(1) + temp); + } } else { #if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \ !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) @@ -60,8 +72,17 @@ struct GeluFunctor { } #else // gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) - auto temp = (x * static_cast(M_SQRT1_2)).erf(); - out.device(d) = x * static_cast(0.5) * (static_cast(1) + temp); + if (std::is_same::value) { + VLOG(4) << "cast from float16 to float before computing"; + auto casted_x = x.template cast(); + auto temp = (casted_x * static_cast(M_SQRT1_2)).erf(); + out.device(d) = (casted_x * static_cast(0.5) * + (static_cast(1) + temp)) + .template cast(); + } else { + auto temp = (x * static_cast(M_SQRT1_2)).erf(); + out.device(d) = x * static_cast(0.5) * (static_cast(1) + temp); + } #endif } } @@ -72,13 +93,32 @@ struct GeluGradFunctor { template void operator()(Device d, X x, dOut dout, dX dx, bool approximate) const { if (approximate) { - const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); - const T kBeta = kAlpha * static_cast(0.044715) * static_cast(3); - const auto y = - (kAlpha * ((static_cast(0.044715) * x.cube()) + x)).tanh(); - dx.device(d) = static_cast(0.5) * dout * - (static_cast(1) + y + - (x - x * y.square()) * (kAlpha + kBeta * x.square())); + if (std::is_same::value) { + VLOG(4) << "cast from float16 to float before computing"; + auto casted_x = x.template cast(); + auto casted_dout = dout.template cast(); + + const float kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); + const float kBeta = + kAlpha * static_cast(0.044715) * static_cast(3); + const auto y = + (kAlpha * + ((static_cast(0.044715) * casted_x.cube()) + casted_x)) + .tanh(); + dx.device(d) = (static_cast(0.5) * casted_dout * + (static_cast(1) + y + + (casted_x - casted_x * y.square()) * + (kAlpha + kBeta * casted_x.square()))) + .template cast(); + } else { + const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); + const T kBeta = kAlpha * static_cast(0.044715) * static_cast(3); + const auto y = + (kAlpha * ((static_cast(0.044715) * x.cube()) + x)).tanh(); + dx.device(d) = static_cast(0.5) * dout * + (static_cast(1) + y + + (x - x * y.square()) * (kAlpha + kBeta * x.square())); + } } else { #if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \ !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) @@ -117,13 +157,26 @@ struct GeluGradFunctor { #else // gelu_grad(x) = dout * 0.5 * (1 + erf(x / sqrt(2)) + x * sqrt(2 / pi) * // exp(- x^2 / 2) - auto first = - static_cast(0.5) * - (static_cast(1) + ((x * static_cast(M_SQRT1_2)).erf())); - - auto second = static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2) * x * - (-static_cast(0.5) * x.square()).exp(); - dx.device(d) = dout * (first + second); + if (std::is_same::value) { + VLOG(4) << "cast from float16 to float before computing"; + auto casted_x = x.template cast(); + auto casted_dout = dout.template cast(); + auto first = static_cast(0.5) * + (static_cast(1) + + ((casted_x * static_cast(M_SQRT1_2)).erf())); + auto second = static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2) * + casted_x * + (-static_cast(0.5) * casted_x.square()).exp(); + dx.device(d) = (casted_dout * (first + second)).template cast(); + } else { + auto first = + static_cast(0.5) * + (static_cast(1) + ((x * static_cast(M_SQRT1_2)).erf())); + + auto second = static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2) * x * + (-static_cast(0.5) * x.square()).exp(); + dx.device(d) = dout * (first + second); + } #endif } } diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index fd2a1e70e2cf0655aab4b663649c15c193465566..74ee233612b3705f2b08b464710ab8998a19b4a2 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -1476,6 +1476,9 @@ class Dropout(layers.Layer): self._is_test = is_test def forward(self, input): + # fast return for p == 0 + if self._dropout_prob == 0: + return input prog = default_main_program() if (self._seed is None or self._seed == 0) and prog.random_seed != 0: self._seed = prog.random_seed diff --git a/python/paddle/fluid/layers/math_op_patch.py b/python/paddle/fluid/layers/math_op_patch.py index 8f5fdf52d95ef99cd8afe5f70535182166f7f5a8..96947bf72c7ddf299a4f4b372be3d62de4aaa1b5 100644 --- a/python/paddle/fluid/layers/math_op_patch.py +++ b/python/paddle/fluid/layers/math_op_patch.py @@ -179,6 +179,7 @@ def monkey_patch_variable(): outputs={"Out": [out]}, attrs={"in_dtype": self.dtype, "out_dtype": out.dtype}) + out.stop_gradient = self.stop_gradient return out def _scalar_op_(var, scale, bias): diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 73ecc5cdf77a03f3eaa4af9f5cd58ea9909d6382..fb246993073b182384c4b148716b50db313ef0d7 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -1018,6 +1018,9 @@ def dropout(x, x = fluid.data(name="data", shape=[None, 32, 32], dtype="float32") dropped = fluid.layers.dropout(x, dropout_prob=0.5) """ + # fast return for p == 0 + if dropout_prob == 0: + return x def get_attrs(prog, dropout_prob, is_test, seed): if (seed is None or seed == 0) and prog.random_seed != 0: diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 0b70010bbee05c91fd26d15f3eae46e9df597951..563933f8cd2e800feb9b6753d960ae3bf012eebf 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -224,6 +224,11 @@ def cast(x, dtype): x = paddle.to_tensor([2, 3, 4], 'float64') y = paddle.cast(x, 'uint8') """ + if in_dygraph_mode(): + if not isinstance(dtype, core.VarDesc.VarType): + dtype = convert_np_dtype_to_dtype_(dtype) + out = core.ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype) + check_variable_and_dtype( x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'], @@ -234,7 +239,8 @@ def cast(x, dtype): ], 'cast') helper = LayerHelper('cast', **locals()) - out = helper.create_variable_for_type_inference(dtype=dtype) + out = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=x.stop_gradient) helper.append_op( type='cast', inputs={'X': [x]}, diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index 0d0273c1670fa7152cc593d690fd6a9b3a13522e..ba2abd72500788c4bbacf3c12d4ba711da1b01f3 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -302,13 +302,16 @@ class TestDropoutFAPI(unittest.TestCase): training=False, mode='downscale_in_infer') res10 = paddle.nn.functional.dropout(x=input, p=1., training=True) + res11 = paddle.fluid.layers.dropout(x=input, dropout_prob=0.) in_np = np.random.random([40, 40]).astype("float32") res_np = in_np res_np2 = np.zeros_like(in_np) exe = fluid.Executor(place) - res_list = [res1, res2, res3, res4, res5, res6, res7, res8, res9] + res_list = [ + res1, res2, res3, res4, res5, res6, res7, res8, res9, res11 + ] for res in res_list: fetches = exe.run(fluid.default_main_program(), feed={"input": in_np}, @@ -383,8 +386,12 @@ class TestDropoutFAPI(unittest.TestCase): mode='downscale_in_infer') res10 = paddle.nn.functional.dropout( x=input, p=1., training=True) + dropout = paddle.fluid.dygraph.Dropout(p=0, ) + res11 = dropout(input) - res_list = [res1, res2, res3, res4, res5, res6, res7, res8, res9] + res_list = [ + res1, res2, res3, res4, res5, res6, res7, res8, res9, res11 + ] for res in res_list: self.assertTrue(np.allclose(res.numpy(), res_np)) self.assertTrue(np.allclose(res10.numpy(), res_np2)) diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch.py b/python/paddle/fluid/tests/unittests/test_math_op_patch.py index 76e371b216778f9b38807a7cebf7a5f717a6d044..fc5e613decddea2f7e2cd5a0e5b672d9bbd8dcfb 100644 --- a/python/paddle/fluid/tests/unittests/test_math_op_patch.py +++ b/python/paddle/fluid/tests/unittests/test_math_op_patch.py @@ -257,6 +257,19 @@ class TestMathOpPatches(unittest.TestCase): fetch_list=[b]) self.assertTrue(numpy.allclose(-a_np, b_np)) + @prog_scope() + def test_astype(self): + a = fluid.layers.data(name="a", shape=[10, 1]) + b = a.astype('float32') + place = fluid.CPUPlace() + exe = fluid.Executor(place) + a_np = numpy.random.uniform(-1, 1, size=[10, 1]).astype('float64') + + b_np = exe.run(fluid.default_main_program(), + feed={"a": a_np}, + fetch_list=[b]) + self.assertTrue(numpy.allclose(a_np.astype('float32'), b_np)) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index b3bdf1e95cc75dfeb658ad0cbd5303aaeb8f953e..7319b860db8f79262c6f5ab307bed15145472ce7 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -887,6 +887,10 @@ def dropout(x, print(y_01) """ + # fast return for p == 0 + if p == 0: + return x + if not isinstance(p, (float, int)): raise TypeError("p argument should be a number") if p < 0 or p > 1: