Unverified commit 9e85d023 authored by Yiqun Liu, committed by GitHub

Avoid a crash when calling ctx->HasInputs, and add a check of the shape attribute in the fill_constant op. (#23698)

Parent ac4da77a
@@ -721,6 +721,9 @@ CompileTimeInferShapeContext::CompileTimeInferShapeContext(
     : op_(op), block_(block) {}
 
 bool CompileTimeInferShapeContext::HasInput(const std::string &name) const {
+  if (op_.Inputs().find(name) == op_.Inputs().end()) {
+    return false;
+  }
   const std::vector<std::string> &input_names = op_.Input(name);
   auto length = input_names.size();
   if (length == 0) {
@@ -734,6 +737,9 @@ bool CompileTimeInferShapeContext::HasInput(const std::string &name) const {
 }
 
 bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const {
+  if (op_.Outputs().find(name) == op_.Outputs().end()) {
+    return false;
+  }
   const std::vector<std::string> &output_names = op_.Output(name);
   auto length = output_names.size();
   if (length == 0) {
@@ -747,6 +753,9 @@ bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const {
 }
 
 bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
+  if (op_.Inputs().find(name) == op_.Inputs().end()) {
+    return false;
+  }
   const std::vector<std::string> &input_names = op_.Input(name);
   if (input_names.empty()) {
     return false;
@@ -758,6 +767,9 @@ bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
 }
 
 bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const {
+  if (op_.Outputs().find(name) == op_.Outputs().end()) {
+    return false;
+  }
   const std::vector<std::string> &output_names = op_.Output(name);
   if (output_names.empty()) {
     return false;
@@ -25,6 +25,16 @@ class FillConstantOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "FillConstant");
     auto& shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
+    if (!ctx->HasInput("ShapeTensor") && !ctx->HasInputs("ShapeTensorList")) {
+      for (size_t i = 0; i < shape.size(); ++i) {
+        PADDLE_ENFORCE_GE(
+            shape[i], 0,
+            platform::errors::InvalidArgument(
+                "Each value of attribute 'shape' is expected to be no less "
+                "than 0. But received: shape[%u] = %d; shape = [%s].",
+                i, shape[i], framework::make_ddim(shape)));
+      }
+    }
     if (shape.empty() && ctx->HasInput("ShapeTensor")) {
       auto shape_dims = ctx->GetInputDim("ShapeTensor");
@@ -369,11 +369,11 @@ class SeqPGAgent(object):
         self.probs, self.samples, self.sample_length = self.model(
             source, source_length, target, target_length)
         self.samples.stop_gradient = True
-        self.reward = fluid.layers.create_global_var(
+        self.reward = fluid.data(
             name="reward",
-            shape=[-1, -1],  # batch_size, seq_len
-            value="1",
+            shape=[None, None],  # batch_size, seq_len
             dtype=self.probs.dtype)
         self.samples.stop_gradient = False
         self.cost = self.alg.learn(self.probs, self.samples, self.reward,
                                    self.sample_length)