未验证 提交 76fb95fe 编写于 作者: W wangchaochaohu 提交者: GitHub

avoid data transform for linspace OP (#27444)

上级 a0452475
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/linspace_op.h"
#include <string>
namespace paddle {
namespace operators {
......@@ -21,7 +22,7 @@ class LinspaceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Start"), "Input", "Start", "linspace");
OP_INOUT_CHECK(ctx->HasInput("Stop"), "Input", "Stop", "linspace");
OP_INOUT_CHECK(ctx->HasInput("Num"), "Input", "Num", "linspace");
......@@ -50,11 +51,17 @@ class LinspaceOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
return expected_kernel_type;
}
};
class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -1453,11 +1453,14 @@ def linspace(start, stop, num, dtype=None, name=None):
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
if not isinstance(start, Variable):
tensor_start = fill_constant([1], dtype, start)
with device_guard("cpu"):
tensor_start = fill_constant([1], dtype, start)
if not isinstance(stop, Variable):
tensor_stop = fill_constant([1], dtype, stop)
with device_guard("cpu"):
tensor_stop = fill_constant([1], dtype, stop)
if not isinstance(num, Variable):
tensor_num = fill_constant([1], 'int32', num)
with device_guard("cpu"):
tensor_num = fill_constant([1], 'int32', num)
if in_dygraph_mode():
return core.ops.linspace(tensor_start, tensor_stop, tensor_num, 'dtype',
dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册