diff --git a/paddle/framework/net_op_test.cc b/paddle/framework/net_op_test.cc index 8048311fe54ee1827fb5b91577478a1d30803e43..44dea97ef0eeee28b717aed8bf1c5008fd6f3738 100644 --- a/paddle/framework/net_op_test.cc +++ b/paddle/framework/net_op_test.cc @@ -16,8 +16,7 @@ static int run_cnt = 0; class TestOp : public OperatorBase { public: - void InferShape( - const std::shared_ptr& scope) const override { + void InferShape(const std::shared_ptr& scope) const override { ++infer_shape_cnt; } void Run(const std::shared_ptr& scope, diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 3a1ffc02151f42a4fe6f103925ab424251ee8d85..9bf60b7b11636df97031d111031ef782a173006b 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -20,7 +20,7 @@ namespace paddle { namespace framework { template <> -Eigen::DefaultDevice* KernelContext::GetEigenDevice< +Eigen::DefaultDevice* ExecutionContext::GetEigenDevice< platform::CPUPlace, Eigen::DefaultDevice>() const { return device_context_.get_eigen_device(); } @@ -28,7 +28,7 @@ Eigen::DefaultDevice* KernelContext::GetEigenDevice< #ifndef PADDLE_ONLY_CPU template <> Eigen::GpuDevice* -KernelContext::GetEigenDevice() const { +ExecutionContext::GetEigenDevice() const { return device_context_.get_eigen_device(); } #endif diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 0a8c82ee47521713fa96cb423ceca4de858c260c..ef1521b83bb50774d7b4f710a5deff879373ba35 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -31,22 +31,9 @@ limitations under the License. */ namespace paddle { namespace framework { -template -struct EigenDeviceConverter; - -template <> -struct EigenDeviceConverter { - using EigenDeviceType = Eigen::DefaultDevice; -}; - -#ifndef PADDLE_ONLY_CPU -template <> -struct EigenDeviceConverter { - using EigenDeviceType = Eigen::GpuDevice; -}; -#endif - class OperatorBase; +class InferShapeContext; +class ExecutionContext; /** * OperatorBase has the basic element that Net will call to do computation. * Only CreateOperator from OpRegistry will new Operator directly. User @@ -112,46 +99,127 @@ class OperatorBase { std::shared_ptr> in_out_idxs_; }; -class KernelContext { +class OperatorContext { public: - KernelContext(const OperatorBase* op, const std::shared_ptr& scope, - const platform::DeviceContext& device_context) - : op_(*op), scope_(scope), device_context_(device_context) {} + OperatorContext(const OperatorBase* op, const std::shared_ptr& scope) + : op_(*op), scope_(scope) {} + + size_t InputSize() const { return op_.inputs_.size(); } - const Variable* Input(int index) const { - return scope_->GetVariable(op_.inputs_[index]); + size_t OutputSize() const { return op_.outputs_.size(); } + + const Variable* InputVar(const size_t& index) const { + return scope_->GetVariable(op_.inputs_.at(index)); } - Variable* Output(int index) const { - return scope_->GetVariable(op_.outputs_[index]); + Variable* OutputVar(const size_t& index) const { + return scope_->GetVariable(op_.outputs_.at(index)); } - const Variable* Input(const std::string& name) const { + const Variable* InputVar(const std::string& name) const { return scope_->GetVariable(op_.Input(name)); } - const Variable* Output(const std::string& name) const { + Variable* OutputVar(const std::string& name) const { return scope_->GetVariable(op_.Output(name)); } - const std::vector Inputs(const std::string& name) const { + const std::vector MultiInputVar( + const std::string& name) const { auto names = op_.Inputs(name); std::vector res; + res.reserve(names.size()); std::transform( - names.begin(), names.end(), res.begin(), + names.begin(), names.end(), std::back_inserter(res), [this](const std::string& name) { return scope_->GetVariable(name); }); return res; } - const std::vector Outputs(const std::string& name) const { + std::vector MultiOutputVar(const std::string& name) const { auto names = op_.Outputs(name); std::vector res; + res.reserve(names.size()); std::transform( - names.begin(), names.end(), res.begin(), + names.begin(), names.end(), std::back_inserter(res), [this](const std::string& name) { return scope_->GetVariable(name); }); return res; } + template + const T* Input(const size_t& index) const { + return &(InputVar(index)->Get()); + } + + template + T* Output(const size_t& index) const { + return OutputVar(index)->GetMutable(); + } + + template + const T* Input(const std::string& name) const { + return &(InputVar(name)->Get()); + } + + template + T* Output(const std::string& name) const { + return OutputVar(name)->GetMutable(); + } + + template + const std::vector MultiInput(const std::string& name) const { + auto names = op_.Inputs(name); + std::vector res; + res.reserve(names.size()); + std::transform(names.begin(), names.end(), std::back_inserter(res), + [this](const std::string& name) { + return &scope_->GetVariable(name)->Get(); + }); + return res; + } + + template + std::vector MultiOutput(const std::string& name) const { + auto names = op_.Outputs(name); + std::vector res; + res.reserve(names.size()); + std::transform(names.begin(), names.end(), std::back_inserter(res), + [this](const std::string& name) { + return scope_->GetVariable(name)->GetMutable(); + }); + return res; + } + + const OperatorBase& op_; + const std::shared_ptr& scope_; +}; + +class InferShapeContext : public OperatorContext { + public: + InferShapeContext(const OperatorBase* op, const std::shared_ptr& scope) + : OperatorContext(op, scope) {} +}; + +template +struct EigenDeviceConverter; + +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::DefaultDevice; +}; + +#ifndef PADDLE_ONLY_CPU +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::GpuDevice; +}; +#endif + +class ExecutionContext : public OperatorContext { + public: + ExecutionContext(const OperatorBase* op, const std::shared_ptr& scope, + const platform::DeviceContext& device_context) + : OperatorContext(op, scope), device_context_(device_context) {} + template ::EigenDeviceType> @@ -159,38 +227,23 @@ class KernelContext { platform::Place GetPlace() const { return device_context_.GetPlace(); } - const OperatorBase& op_; - const std::shared_ptr& scope_; const platform::DeviceContext& device_context_; }; class OpKernel { public: /** - * KernelContext is the only parameter of Kernel Run function. + * ExecutionContext is the only parameter of Kernel Run function. * Run will get input/output variables, state such as momentum and * device resource such as CUDA stream, cublas handle, etc. from - * KernelContext. User should construct it before run the Operator. + * ExecutionContext. User should construct it before run the Operator. */ - virtual void Compute(const KernelContext& context) const = 0; + virtual void Compute(const ExecutionContext& context) const = 0; virtual ~OpKernel() {} }; -template -struct VarToTensor {}; - -template <> -struct VarToTensor { - Tensor* operator()(Variable* var) { return var->GetMutable(); } -}; - -template <> -struct VarToTensor { - const Tensor* operator()(Variable* var) { return &var->Get(); } -}; - class OperatorWithKernel : public OperatorBase { public: struct OpKernelKey { @@ -216,10 +269,14 @@ class OperatorWithKernel : public OperatorBase { using OpKernelMap = std::unordered_map, OpKernelHash>; + void InferShape(const std::shared_ptr& scope) const { + InferShape(InferShapeContext(this, scope)); + } + void Run(const std::shared_ptr& scope, const platform::DeviceContext& dev_ctx) const final { auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); - opKernel->Compute(KernelContext(this, scope, dev_ctx)); + opKernel->Compute(ExecutionContext(this, scope, dev_ctx)); } static std::unordered_map& @@ -228,34 +285,8 @@ class OperatorWithKernel : public OperatorBase { return g_all_op_kernels; } - void InferShape(const std::shared_ptr& scope) const final { - std::vector ins; - VarNamesToTensors(scope, inputs_, &ins); - std::vector outs; - VarNamesToTensors(scope, outputs_, &outs); - InferShape(ins, outs); - }; - - private: - template - void VarNamesToTensors(const std::shared_ptr& scope, - const std::vector& var_names, - std::vector* container) const { - container->reserve(var_names.size()); - VarToTensor convert; - for (auto& name : var_names) { - auto var = scope->GetVariable(name); - if (var != nullptr) { - container->push_back(convert(var)); - } else { - container->push_back(nullptr); - } - } - } - protected: - virtual void InferShape(const std::vector& inputs, - const std::vector& outputs) const = 0; + virtual void InferShape(const InferShapeContext& ctx) const = 0; }; } // namespace framework diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 3fae356c3e5d5b44271440b66d6923fd4994b937..daa3645b4d7588baffc57491d0a7f7f6368eda7b 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -24,7 +24,8 @@ static int op_run_num = 0; class OpWithoutKernelTest : public OperatorBase { public: void Init() override { x = 1; } - void InferShape(const std::shared_ptr& scope) const override {} + void InferShape( + const std::shared_ptr& scope) const override {} void Run(const std::shared_ptr& scope, const platform::DeviceContext& dev_ctx) const override { op_run_num++; @@ -73,6 +74,7 @@ TEST(OperatorBase, all) { auto op = paddle::framework::OpRegistry::CreateOp(op_desc); scope->CreateVariable("OUT1"); ASSERT_EQ(paddle::framework::op_run_num, 0); + op->InferShape(scope); op->Run(scope, device_context); ASSERT_EQ(paddle::framework::op_run_num, 1); } @@ -97,14 +99,13 @@ static int cpu_kernel_run_num = 0; class OpWithKernelTest : public OperatorWithKernel { protected: - void InferShape(const std::vector& inputs, - const std::vector& outputs) const override {} + void InferShape(const framework::InferShapeContext& ctx) const override {} }; template class CPUKernelTest : public OpKernel { public: - void Compute(const KernelContext& ctx) const { + void Compute(const ExecutionContext& ctx) const { std::cout << "this is cpu kernel" << std::endl; std::cout << ctx.op_.DebugString() << std::endl; cpu_kernel_run_num++; @@ -117,7 +118,8 @@ class CPUKernelTest : public OpKernel { class OperatorMultiInputsTest : public OperatorBase { public: void Init() override { x = 1; } - void InferShape(const std::shared_ptr& scope) const override {} + void InferShape( + const std::shared_ptr& scope) const override {} void Run(const std::shared_ptr& scope, const platform::DeviceContext& dev_ctx) const override { ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); @@ -149,13 +151,31 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker class CPUKernalMultiInputsTest : public OpKernel { public: - void Compute(const KernelContext& ctx) const { + void Compute(const ExecutionContext& ctx) const { auto xs = ctx.op_.Inputs("xs"); ASSERT_EQ(xs.size(), 3UL); ASSERT_EQ(xs[0], "x0"); ASSERT_EQ(xs[1], "x1"); ASSERT_EQ(xs[2], "x2"); + auto inVar0 = ctx.MultiInputVar("xs"); + ASSERT_EQ(inVar0.size(), 3); + + auto intVar1 = ctx.InputVar("k"); + ASSERT_NE(intVar1, nullptr); + + auto outVar0 = ctx.MultiOutputVar("ys"); + ASSERT_EQ(outVar0.size(), 2); + + auto inTensor0 = ctx.MultiInput("xs"); + ASSERT_EQ(inTensor0.size(), 3); + + auto intTensor1 = ctx.Input("k"); + ASSERT_NE(intTensor1, nullptr); + + auto outTensor0 = ctx.MultiOutput("ys"); + ASSERT_EQ(outTensor0.size(), 2); + auto k = ctx.op_.Input("k"); ASSERT_EQ(k, "k0"); @@ -233,6 +253,12 @@ TEST(OpKernel, multi_inputs) { paddle::platform::CPUDeviceContext cpu_device_context; auto scope = std::make_shared(); + scope->CreateVariable("x0")->GetMutable(); + scope->CreateVariable("x1")->GetMutable(); + scope->CreateVariable("x2")->GetMutable(); + scope->CreateVariable("k0")->GetMutable(); + scope->CreateVariable("y0")->GetMutable(); + scope->CreateVariable("y1")->GetMutable(); auto op = paddle::framework::OpRegistry::CreateOp(op_desc); op->Run(scope, cpu_device_context); diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index 1424b0284372d8dfe9eb93ee251b121a48b19b0b..3a43dbfbada87e458109d8ca22effdb4407b4c1d 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -19,16 +19,16 @@ namespace operators { class AddOp : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override { - PADDLE_ENFORCE(inputs.size() == 2, "Input size of AddOp must be two"); - PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one"); - PADDLE_ENFORCE( - inputs[0] != nullptr && inputs[1] != nullptr && outputs[0] != nullptr, - "Inputs/Outputs of AddOp must all be set"); - PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(), + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of AddOp must be two"); + PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one"); + PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr, + "Inputs of AddOp must all be set"); + PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, + "Outputs of AddOp must all be set"); + PADDLE_ENFORCE(ctx.Input(0)->dims() == ctx.Input(1)->dims(), "Two input of Add Op's dimension must be same."); - outputs[0]->Resize(inputs[0]->dims()); + ctx.Output(0)->Resize(ctx.Input(0)->dims()); } }; @@ -49,8 +49,7 @@ The equation is: Out = X + Y class AddOpGrad : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override {} + void InferShape(const InferShapeContext &ctx) const override {} std::string DebugString() const override { LOG(INFO) << "AddOpGrad"; return ""; diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index 0c39433788e1e07e30aaadc4766028219b05bfa5..d2b649fcbd1e5cac1c8cfcfd4e522e41135f7d1f 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -21,16 +21,17 @@ namespace operators { template class AddKernel : public OpKernel { public: - void Compute(const KernelContext& context) const override { - auto input0 = context.Input(0)->Get(); - auto input1 = context.Input(1)->Get(); - auto output = context.Output(0)->GetMutable(); + void Compute(const ExecutionContext& context) const override { + auto input0 = context.Input(0); + auto input1 = context.Input(1); + auto output = context.Output(0); output->mutable_data(context.GetPlace()); EigenVector::Flatten(*output).device( *(context.GetEigenDevice())) = - EigenVector::Flatten(input0) + EigenVector::Flatten(input1); + framework::EigenVector::Flatten(*input0) + + framework::EigenVector::Flatten(*input1); } }; diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index 46c88d4d1a28eeedd02eb699562244651ead6d68..4f5b935fde4d5b0d9efae66554cf890291e26941 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -19,20 +19,20 @@ namespace operators { class OnehotCrossEntropyOp : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override { - PADDLE_ENFORCE(inputs.size() == 2, + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of OnehotCrossEntropyOp must be two"); - PADDLE_ENFORCE(outputs.size() == 1, + PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of OnehotCrossEntropyOp must be one"); - PADDLE_ENFORCE(inputs[0] != nullptr && inputs[1] != nullptr, + PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr, "Inputs of OnehotCrossEntropyOp must all be set"); - PADDLE_ENFORCE(outputs[0] != nullptr, + PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, "Outputs of OnehotCrossEntropyOp must all be set"); - PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2."); - PADDLE_ENFORCE(outputs[0]->dims().size() == 1, + PADDLE_ENFORCE(ctx.Input(0)->dims().size() == 2, + "X's dimension must be 2."); + PADDLE_ENFORCE(ctx.Output(0)->dims().size() == 1, "label's dimension must be 1."); - outputs[0]->Resize({inputs[0]->dims()[0]}); + ctx.Output(0)->Resize({ctx.Input(0)->dims()[0]}); } }; diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index 0383df46be3a3cea7dde8f1b45857e64d5a2f2d8..c3a3728149950a5c7f2195122e8e0ff728492bdb 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -23,18 +23,18 @@ class OnehotCrossEntropyOpKernel : public OpKernel { public: constexpr T LOG_THRESHOLD() const { return static_cast(1e-20); } - void Compute(const KernelContext& context) const override { - auto X = context.Input(0)->Get(); - const T* X_data = X.data(); - const int* label_data = context.Input(1)->Get().data(); - auto* Y = context.Output(0)->GetMutable(); + void Compute(const ExecutionContext& ctx) const override { + auto X = ctx.Input(0); + const T* X_data = X->data(); + const int* label_data = ctx.Input(1)->data(); + auto Y = ctx.Output(0); - Y->mutable_data(context.GetPlace()); + Y->mutable_data(ctx.GetPlace()); T* Y_data = Y->data(); - int batch_size = X.dims()[0]; - int class_num = X.dims()[1]; + int batch_size = X->dims()[0]; + int class_num = X->dims()[1]; // Y[i] = -log(X[i][j]) for (int i = 0; i < batch_size; ++i) { diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 22c1b78005358a934c57d487f5b0cff133f61f0c..d127f3a302a340fe7558f918d6eeb2ea0a3fafe7 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -19,18 +19,17 @@ namespace operators { class MulOp : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override { - PADDLE_ENFORCE(inputs.size() == 2, "The mul op must take two inputs"); - auto dim0 = inputs[0]->dims(); - auto dim1 = inputs[1]->dims(); + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs"); + auto dim0 = ctx.Input(0)->dims(); + auto dim1 = ctx.Input(1)->dims(); PADDLE_ENFORCE(dim0.size() == 2 && dim1.size() == 2, "The input of mul op must be matrix"); PADDLE_ENFORCE( dim0[1] == dim1[0], "First matrix's width must be equal with second matrix's height."); - PADDLE_ENFORCE(outputs.size() == 1, "The mul op must take one output"); - outputs[0]->Resize({dim0[0], dim1[1]}); + PADDLE_ENFORCE(ctx.OutputSize() == 1, "The mul op must take one output"); + ctx.Output(0)->Resize({dim0[0], dim1[1]}); } }; @@ -51,8 +50,7 @@ The equation is: Out = X * Y class MulOpGrad : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override {} + void InferShape(const InferShapeContext &ctx) const override {} std::string DebugString() const override { LOG(INFO) << "MulGrad"; return ""; diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index 467975044638a3f034ceec84173e8d3fed43cc0c..eef72ab293e13a9d05ce0013be41ec4bb75d6077 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -22,19 +22,17 @@ namespace operators { template class MulKernel : public OpKernel { public: - void Compute(const KernelContext& context) const override { + void Compute(const ExecutionContext& context) const override { Eigen::array, 1> dim_pair = { {Eigen::IndexPair(1, 0)}}; - auto input0 = context.Input(0)->Get(); - auto input1 = context.Input(1)->Get(); - auto* output = context.Output(0)->GetMutable(); - + auto output = context.Output(0); output->mutable_data(context.GetPlace()); EigenMatrix::From(*output).device(*(context.GetEigenDevice())) = - EigenMatrix::From(input0).contract(EigenMatrix::From(input1), - dim_pair); + EigenMatrix::From(*context.Input("X")) + .contract(EigenMatrix::From(*context.Input("Y")), + dim_pair); } }; } // namespace operators diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index 4129422fa744b2a7cf135b681efa73ffb2ebcdcc..2ad2b66c8f385c858eb34c7ea766f168de9c817e 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -18,17 +18,17 @@ namespace operators { class RowWiseAddOp : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override { - PADDLE_ENFORCE(inputs.size() == 2UL, "Two inputs is needed by rowwise add"); - auto dim0 = inputs[0]->dims(); - auto dim1 = inputs[1]->dims(); + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 2UL, + "Two inputs is needed by rowwise add"); + auto dim0 = ctx.Input(0)->dims(); + auto dim1 = ctx.Input(1)->dims(); PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix"); PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector"); PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same"); - PADDLE_ENFORCE(outputs.size() == 1, "The output size must be 1"); - outputs[0]->Resize(inputs[0]->dims()); + PADDLE_ENFORCE(ctx.OutputSize() == 1, "The output size must be 1"); + ctx.Output(0)->Resize(ctx.Input(0)->dims()); } }; diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h index 4596925e9322f373c822608fd9aa6ecee6144d4c..b86dd5463436bf521f9939b1c421b39f11102769 100644 --- a/paddle/operators/rowwise_add_op.h +++ b/paddle/operators/rowwise_add_op.h @@ -21,14 +21,12 @@ namespace operators { template class RowWiseAddKernel : public OpKernel { public: - void Compute(const KernelContext& context) const override { - auto in0 = context.Input(0)->Get(); - auto in1 = context.Input(1)->Get(); - auto* out = context.Output(0)->GetMutable(); + void Compute(const ExecutionContext& context) const override { + auto out = context.Output(0); out->mutable_data(context.GetPlace()); - auto input = EigenMatrix::From(in0); - auto bias = EigenVector::From(in1); + auto input = EigenMatrix::From(*context.Input(0)); + auto bias = EigenVector::From(*context.Input(1)); auto output = EigenMatrix::From(*out); const int bias_size = bias.dimension(0); diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc index f6c654a9e7083704e353c276e0abc975f4e61ef9..9a84dc8af3b3e649b776ca8a97dedba1fa3ff48d 100644 --- a/paddle/operators/sgd_op.cc +++ b/paddle/operators/sgd_op.cc @@ -19,16 +19,15 @@ namespace operators { class SGDOp : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override { - PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two"); - PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one"); - PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] mast be set"); - PADDLE_ENFORCE(inputs[1] != nullptr, "inputs[1] mast be set"); - PADDLE_ENFORCE(outputs[0] != nullptr, "outputs[0] mast be set"); - PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(), + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of SGDOp must be two"); + PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of SGDOp must be one"); + PADDLE_ENFORCE(ctx.InputVar(0) != nullptr, "inputs[0] mast be set"); + PADDLE_ENFORCE(ctx.InputVar(1) != nullptr, "inputs[1] mast be set"); + PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, "outputs[0] mast be set"); + PADDLE_ENFORCE(ctx.Input(0)->dims() == ctx.Input(1)->dims(), "Two input of SGD Op's dimension must be same."); - outputs[0]->Resize(inputs[0]->dims()); + ctx.Output(0)->Resize(ctx.Input(0)->dims()); } }; diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index 65179d323bd991b8b4e196c069a11cd901c62082..af1dfdd756ceb9991bee6b85c3281c05f0fb5a9f 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -21,16 +21,16 @@ namespace operators { template class SGDOpKernel : public OpKernel { public: - void Compute(const KernelContext& ctx) const override { - auto param = ctx.Input("param")->Get(); - auto grad = ctx.Input("grad")->Get(); - auto* param_out = ctx.Output(0)->GetMutable(); + void Compute(const ExecutionContext& ctx) const override { + auto param = ctx.Input("param"); + auto grad = ctx.Input("grad"); + auto param_out = ctx.Output(0); float lr = ctx.op_.GetAttr("learning_rate"); param_out->mutable_data(ctx.GetPlace()); EigenVector::Flatten(*param_out).device(*(ctx.GetEigenDevice())) = - EigenVector::Flatten(param) - lr * EigenVector::Flatten(grad); + EigenVector::Flatten(*param) - lr * EigenVector::Flatten(*grad); } }; diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc index 716f1d9c4dbc45e2d5569f8d634b06fd988a149c..a81ab262cc6fe7bdff0045259e0030f3d46f503f 100644 --- a/paddle/operators/sigmoid_op.cc +++ b/paddle/operators/sigmoid_op.cc @@ -18,11 +18,10 @@ namespace operators { class SigmoidOp : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override { - PADDLE_ENFORCE(inputs.size() == 1, "Sigmoid Op only have one input"); - PADDLE_ENFORCE(outputs.size() == 1, "Sigmoid Op only have one output"); - outputs[0]->Resize(inputs[0]->dims()); + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 1, "Sigmoid Op only have one input"); + PADDLE_ENFORCE(ctx.OutputSize() == 1, "Sigmoid Op only have one output"); + ctx.Output(0)->Resize(ctx.Input(0)->dims()); } }; @@ -38,8 +37,7 @@ public: class SigmoidOpGrad : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override {} + void InferShape(const InferShapeContext &ctx) const override {} std::string DebugString() const override { LOG(INFO) << "SigmoidGrad"; return ""; diff --git a/paddle/operators/sigmoid_op.h b/paddle/operators/sigmoid_op.h index 896a6f5d83e0f96de50e3aaae6f545172bf5da14..3dd23a9ebc7ac0972d6ee07b9ac051d59e66f62f 100644 --- a/paddle/operators/sigmoid_op.h +++ b/paddle/operators/sigmoid_op.h @@ -22,15 +22,14 @@ namespace operators { template class SigmoidKernel : public OpKernel { public: - void Compute(const KernelContext& context) const override { - auto input = context.Input(0)->Get(); - auto* output = context.Output(0)->GetMutable(); - + void Compute(const ExecutionContext& context) const override { + auto input = context.Input(0); + auto output = context.Output(0); output->mutable_data(context.GetPlace()); EigenVector::Flatten(*output).device( *(context.GetEigenDevice())) = - 1.0 / (1.0 + (-1.0 * EigenVector::Flatten(input)).exp()); + 1.0 / (1.0 + (-1.0 * EigenVector::Flatten(*input)).exp()); } }; } // namespace operators diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index df60b62fa6ac8d67c9dadc40ec49aaedab92bc88..5b59fad7d5f9729b0862f8cd78cb32f94f87f513 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -18,14 +18,13 @@ namespace operators { class SoftmaxOp : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override { - PADDLE_ENFORCE(inputs.size() == 1, "Only one input is need for softmax"); - PADDLE_ENFORCE(inputs[0]->dims().size() == 2, + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 1, "Only one input is need for softmax"); + PADDLE_ENFORCE(ctx.Input(0)->dims().size() == 2, "The input of softmax op must be matrix"); - PADDLE_ENFORCE(outputs.size() == 1, "Only one output is need for softmax"); - - outputs[0]->Resize(inputs[0]->dims()); + PADDLE_ENFORCE(ctx.OutputSize() == 1, + "Only one output is need for softmax"); + ctx.Output(0)->Resize(ctx.Input(0)->dims()); } }; @@ -41,8 +40,7 @@ public: class SoftmaxOpGrad : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override {} + void InferShape(const InferShapeContext &ctx) const override {} std::string DebugString() const override { LOG(INFO) << "SoftmaxOpGrad"; return ""; diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index 625a87b58560231572c1cca2a21bd0c47c8cb296..a5c19c5fc7c6f5909dbb355aff09bf15405b6957 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -22,12 +22,12 @@ namespace operators { template class SoftmaxKernel : public OpKernel { public: - void Compute(const KernelContext& context) const override { - auto input = context.Input(0)->Get(); - auto* output = context.Output(0)->GetMutable(); + void Compute(const ExecutionContext& context) const override { + auto input = context.Input(0); + auto output = context.Output(0); output->mutable_data(context.GetPlace()); - auto logits = EigenMatrix::From(input); + auto logits = EigenMatrix::From(*input); auto softmax = EigenMatrix::From(*output); const int kBatchDim = 0; diff --git a/paddle/operators/type_alias.h b/paddle/operators/type_alias.h index b712e457ff60e8b30b87c0d549693d53e9f05d59..9d1f5fba2ad3ada4742ada30b41d68d15a69ca45 100644 --- a/paddle/operators/type_alias.h +++ b/paddle/operators/type_alias.h @@ -22,7 +22,9 @@ namespace paddle { namespace operators { using OpKernel = framework::OpKernel; -using KernelContext = framework::KernelContext; +using InferShapeContext = framework::InferShapeContext; +using ExecutionContext = framework::ExecutionContext; +using Variable = framework::Variable; template