未验证 提交 11061325 编写于 作者: Q Qi Li 提交者: GitHub

[NPU] fix softmax_with_cross_entropy in dygraph, test=develop (#36297)

上级 64d08c0e
......@@ -13,10 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
......@@ -54,8 +50,7 @@ class SoftmaxWithCrossEntropyOpMaker
"exp(logits -max_logits) / sum(exp(logits - max_logits)) - labels, "
"where labels is ont-hot."
"Currently, the tensor is generated and used in npu kernel only. ")
.AsIntermediate()
.AsDispensable();
.AsIntermediate();
#endif
AddOutput("Loss",
"(Tensor, default: Tensor<float>), A tensor in same shape with "
......@@ -136,6 +131,11 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->HasOutput("Softmax"), true,
platform::errors::InvalidArgument(
"Output(Softmax) should be not null."));
#ifdef PADDLE_WITH_ASCEND_CL
PADDLE_ENFORCE_EQ(ctx->HasOutput("Backprop"), true,
platform::errors::InvalidArgument(
"Output(Backprop) should be not null."));
#endif
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Loss"), true,
platform::errors::InvalidArgument("Output(Loss) should be not null."));
......@@ -225,6 +225,11 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->HasInput("Softmax"), true,
platform::errors::InvalidArgument(
"Input(Softmax) should be not null."));
#ifdef PADDLE_WITH_ASCEND_CL
PADDLE_ENFORCE_EQ(ctx->HasInput("Backprop"), true,
platform::errors::InvalidArgument(
"Input(Backprop) should be not null."));
#endif
PADDLE_ENFORCE_EQ(
ctx->HasInput("Label"), true,
platform::errors::InvalidArgument("Input(Label) should be not null."));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册