未验证 提交 76fb95fe 编写于 作者: W wangchaochaohu 提交者: GitHub

avoid data transform for linspace OP (#27444)

上级 a0452475
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/linspace_op.h" #include "paddle/fluid/operators/linspace_op.h"
#include <string>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -21,7 +22,7 @@ class LinspaceOp : public framework::OperatorWithKernel { ...@@ -21,7 +22,7 @@ class LinspaceOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Start"), "Input", "Start", "linspace"); OP_INOUT_CHECK(ctx->HasInput("Start"), "Input", "Start", "linspace");
OP_INOUT_CHECK(ctx->HasInput("Stop"), "Input", "Stop", "linspace"); OP_INOUT_CHECK(ctx->HasInput("Stop"), "Input", "Stop", "linspace");
OP_INOUT_CHECK(ctx->HasInput("Num"), "Input", "Num", "linspace"); OP_INOUT_CHECK(ctx->HasInput("Num"), "Input", "Num", "linspace");
...@@ -50,11 +51,17 @@ class LinspaceOp : public framework::OperatorWithKernel { ...@@ -50,11 +51,17 @@ class LinspaceOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::proto::VarType::Type(ctx.Attr<int>("dtype")), framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace()); ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
return expected_kernel_type;
}
}; };
class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker { class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -1453,11 +1453,14 @@ def linspace(start, stop, num, dtype=None, name=None): ...@@ -1453,11 +1453,14 @@ def linspace(start, stop, num, dtype=None, name=None):
if not isinstance(dtype, core.VarDesc.VarType): if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
if not isinstance(start, Variable): if not isinstance(start, Variable):
tensor_start = fill_constant([1], dtype, start) with device_guard("cpu"):
tensor_start = fill_constant([1], dtype, start)
if not isinstance(stop, Variable): if not isinstance(stop, Variable):
tensor_stop = fill_constant([1], dtype, stop) with device_guard("cpu"):
tensor_stop = fill_constant([1], dtype, stop)
if not isinstance(num, Variable): if not isinstance(num, Variable):
tensor_num = fill_constant([1], 'int32', num) with device_guard("cpu"):
tensor_num = fill_constant([1], 'int32', num)
if in_dygraph_mode(): if in_dygraph_mode():
return core.ops.linspace(tensor_start, tensor_stop, tensor_num, 'dtype', return core.ops.linspace(tensor_start, tensor_stop, tensor_num, 'dtype',
dtype) dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册