提交 244a9e06 编写于 作者: L liuwei1031 提交者: GitHub

migrate several ops: (#17606)

* migrate several ops:
  mean,
  mean_grad
  fill_constant
  square_grad
  elementwise_sub_grad
  mul_grad

* add sdg_op

* fix kernel platform registration issue

* code cleanup

* fix platform typo
上级 08f0306d
...@@ -25,11 +25,11 @@ set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inferenc ...@@ -25,11 +25,11 @@ set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inferenc
set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING
"A path setting inference demo download directories.") "A path setting inference demo download directories.")
# lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc # lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc
# DEPS cxx_api_lite model_parser_lite target_wrapper_host # DEPS cxx_api_lite model_parser_lite target_wrapper_host
# ${ops_lite} ${host_kernels} ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model # ${ops_lite} ${host_kernels} ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model
# --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) # --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
if(WITH_TESTING) if(WITH_TESTING)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_naive_model.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_naive_model.tar.gz")
# add_dependencies(test_cxx_api_lite extern_lite_download_lite_naive_model_tar_gz) # add_dependencies(test_cxx_api_lite extern_lite_download_lite_naive_model_tar_gz)
......
...@@ -92,7 +92,7 @@ TEST(CXXTrainer, train) { ...@@ -92,7 +92,7 @@ TEST(CXXTrainer, train) {
main_program_desc.ParseFromString(main_program_pb); main_program_desc.ParseFromString(main_program_pb);
startup_program_desc.ParseFromString(startup_program_pb); startup_program_desc.ParseFromString(startup_program_pb);
LOG(INFO) << main_program_desc.DebugString(); // LOG(INFO) << main_program_desc.DebugString();
for (const auto& op : main_program_desc.blocks(0).ops()) { for (const auto& op : main_program_desc.blocks(0).ops()) {
LOG(INFO) << "get op " << op.type(); LOG(INFO) << "get op " << op.type();
......
...@@ -2,5 +2,19 @@ if(NOT LITE_WITH_X86) ...@@ -2,5 +2,19 @@ if(NOT LITE_WITH_X86)
return() return()
endif() endif()
cc_library(activation_compute SRCS activation_compute.cc DEPS ${lite_kernel_deps} activation_op) cc_library(activation_compute_x86 SRCS activation_compute.cc DEPS ${lite_kernel_deps} activation_op)
cc_library(elementwise_compute SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} elementwise_sub_op) cc_library(elementwise_compute_x86 SRCS elementwise_compute.cc DEPS ${lite_kernel_deps})
cc_library(mean_compute_x86 SRCS mean_compute.cc DEPS ${lite_kernel_deps})
cc_library(fill_constant_compute_x86 SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps})
cc_library(mul_compute_x86 SRCS mul_compute.cc DEPS ${lite_kernel_deps})
cc_library(sgd_compute_x86 SRCS sgd_compute.cc DEPS ${lite_kernel_deps})
set(x86_kernels
activation_compute_x86
elementwise_compute_x86
mean_compute_x86
fill_constant_compute_x86
mul_compute_x86
)
set(x86_kernels "${x86_kernels}" CACHE INTERNAL "x86 kernels")
...@@ -55,7 +55,7 @@ void ActivateGrad(const platform::CPUDeviceContext& context, ...@@ -55,7 +55,7 @@ void ActivateGrad(const platform::CPUDeviceContext& context,
} }
template <typename T> template <typename T>
class SquareCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> { class SquareCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public: public:
using param_t = operators::ActivationParam; using param_t = operators::ActivationParam;
...@@ -70,14 +70,11 @@ class SquareCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> { ...@@ -70,14 +70,11 @@ class SquareCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
&param.Out->raw_tensor()); &param.Out->raw_tensor());
} }
// TargetType target() const override;
// PrecisionType precision() const override;
virtual ~SquareCompute() = default; virtual ~SquareCompute() = default;
}; };
template <typename T> template <typename T>
class SquareGradCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> { class SquareGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public: public:
using param_t = operators::ActivationGradParam; using param_t = operators::ActivationGradParam;
...@@ -93,9 +90,6 @@ class SquareGradCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> { ...@@ -93,9 +90,6 @@ class SquareGradCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
&param.X_grad->raw_tensor()); &param.X_grad->raw_tensor());
} }
// TargetType target() const override;
// PrecisionType precision() const override;
virtual ~SquareGradCompute() = default; virtual ~SquareGradCompute() = default;
}; };
...@@ -107,16 +101,16 @@ class SquareGradCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> { ...@@ -107,16 +101,16 @@ class SquareGradCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
// float // float
REGISTER_LITE_KERNEL(square, kX86, kFloat, kNCHW, REGISTER_LITE_KERNEL(square, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::SquareCompute<float>, def) paddle::lite::kernels::x86::SquareCompute<float>, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kHost))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(square_grad, kX86, kFloat, kNCHW, REGISTER_LITE_KERNEL(square_grad, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::SquareGradCompute<float>, def) paddle::lite::kernels::x86::SquareGradCompute<float>, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kHost))}) .BindInput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kHost))}) .BindInput(paddle::framework::GradVarName("Out"),
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput(paddle::framework::GradVarName("X"),
{LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
...@@ -32,7 +32,7 @@ struct SubFunctor { ...@@ -32,7 +32,7 @@ struct SubFunctor {
template <typename T> template <typename T>
class ElementwiseSubCompute class ElementwiseSubCompute
: public KernelLite<TARGET(kHost), PRECISION(kFloat)> { : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public: public:
using param_t = operators::ElementwiseParam; using param_t = operators::ElementwiseParam;
...@@ -49,22 +49,67 @@ class ElementwiseSubCompute ...@@ -49,22 +49,67 @@ class ElementwiseSubCompute
&param.Out->raw_tensor()); &param.Out->raw_tensor());
} }
// TargetType target() const override;
// PrecisionType precision() const override;
virtual ~ElementwiseSubCompute() = default; virtual ~ElementwiseSubCompute() = default;
}; };
template <typename T>
struct SubGradDX {
T operator()(T x, T y, T out, T dout) const { return dout; }
};
template <typename T>
struct SubGradDY {
T operator()(T x, T y, T out, T dout) const { return -dout; }
};
template <typename T>
class ElementwiseSubGradCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ElementwiseGradParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = context_->As<X86Context>();
CHECK(context.x86_device_context);
param.X_grad->template mutable_data<T>();
param.Y_grad->template mutable_data<T>();
// skip out, x, y
auto dout = param.Out_grad->raw_tensor();
auto dx = param.X_grad->raw_tensor();
auto dy = param.Y_grad->raw_tensor();
auto& skip = dout;
paddle::operators::ElemwiseExplicitGradCompute<
platform::CPUDeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
*context.x86_execution_context, skip, skip, skip, dout, param.axis, &dx,
&dy, SubGradDX<T>(), SubGradDY<T>());
}
virtual ~ElementwiseSubGradCompute() = default;
};
} // namespace x86 } // namespace x86
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
// float // float
REGISTER_LITE_KERNEL(square, kHost, kFloat, kNCHW, REGISTER_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::ElementwiseSubCompute<float>, paddle::lite::kernels::x86::ElementwiseSubCompute<float>,
def) def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(elementwise_sub_grad, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::ElementwiseSubCompute<float>,
def)
.BindInput(paddle::framework::GradVarName("Out"),
{LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput(paddle::framework::GradVarName("X"),
{LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput(paddle::framework::GradVarName("Y"),
{LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class FillConstantCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::FillConstantParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = context_->As<X86Context>();
CHECK(context.x86_device_context);
param.Out->template mutable_data<T>();
paddle::operators::math::set_constant(
*context.x86_device_context, &param.Out->raw_tensor(), param.value);
}
virtual ~FillConstantCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// float
REGISTER_LITE_KERNEL(fill_constant, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::FillConstantCompute<float>,
def)
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T>
class MeanCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::MeanParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = context_->As<X86Context>();
CHECK(context.x86_device_context);
param.Out->template mutable_data<T>();
auto X = EigenVector<T>::Flatten(param.X->raw_tensor());
auto y = EigenScalar<T>::From(param.Out->raw_tensor());
const auto& place = *(context.x86_device_context->eigen_device());
y.device(place) = X.mean();
}
virtual ~MeanCompute() = default;
};
template <typename T>
class MeanGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::MeanGradParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = context_->As<X86Context>();
CHECK_EQ(param.Out_grad->raw_tensor().numel(), 1);
CHECK(context.x86_device_context);
param.X_grad->template mutable_data<T>();
T x_grad_size = static_cast<T>(param.X_grad->raw_tensor().numel());
Eigen::DSizes<int, 1> bcast(static_cast<int>(x_grad_size));
EigenVector<T>::Flatten(param.X_grad->raw_tensor())
.device(*(context.x86_device_context->eigen_device())) =
(EigenVector<T>::From(param.Out_grad->raw_tensor()) / x_grad_size)
.broadcast(bcast);
}
virtual ~MeanGradCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// float
REGISTER_LITE_KERNEL(mean, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::MeanCompute<float>, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
REGISTER_LITE_KERNEL(mean_grad, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::MeanGradCompute<float>, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput(paddle::framework::GradVarName("Out"),
{LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput(paddle::framework::GradVarName("X"),
{LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
using Tensor = framework::Tensor;
template <typename T>
class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::MulParam;
void Run() override {
auto& context = context_->As<X86Context>();
auto& param = *param_.get_mutable<operators::MulParam>();
CHECK(context.x86_device_context);
param.output->template mutable_data<T>();
auto* x = &param.x->raw_tensor();
auto* y = &param.y->raw_tensor();
const Tensor x_matrix = x->dims().size() > 2 ? framework::ReshapeToMatrix(
*x, param.x_num_col_dims)
: *x;
const Tensor y_matrix = y->dims().size() > 2 ? framework::ReshapeToMatrix(
*y, param.y_num_col_dims)
: *y;
auto* z = &param.output->raw_tensor();
auto z_dim = z->dims();
if (z_dim.size() != 2) {
z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
}
auto blas = paddle::operators::math::GetBlas<platform::CPUDeviceContext, T>(
*context.x86_device_context);
blas.MatMul(x_matrix, y_matrix, z);
if (z_dim.size() != 2) {
z->Resize(z_dim);
}
}
virtual ~MulCompute() = default;
};
template <typename T>
class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
void Run() override {
auto& context = context_->As<X86Context>();
auto& param = *param_.get_mutable<operators::MulGradParam>();
CHECK(context.x86_device_context);
auto* x = &param.x->raw_tensor();
auto* y = &param.y->raw_tensor();
auto x_matrix = x->dims().size() > 2
? framework::ReshapeToMatrix(*x, param.x_num_col_dims)
: static_cast<const Tensor&>(*x);
auto y_matrix = y->dims().size() > 2
? framework::ReshapeToMatrix(*y, param.y_num_col_dims)
: static_cast<const Tensor&>(*y);
auto* dout = &param.output_grad->raw_tensor();
Tensor dout_mat;
dout_mat.ShareDataWith(*dout);
dout_mat.Resize(
{framework::flatten_to_2d(x->dims(), param.x_num_col_dims)[0],
framework::flatten_to_2d(y->dims(), param.y_num_col_dims)[1]});
auto* dx = &param.x_grad->raw_tensor();
auto* dy = &param.y_grad->raw_tensor();
if (dx != nullptr) {
dx->set_lod(x->lod());
}
if (dy != nullptr) {
dy->set_lod(y->lod());
}
auto blas = paddle::operators::math::GetBlas<platform::CPUDeviceContext, T>(
*context.x86_device_context);
if (dx) {
// dx->mutable_data<T>(context.x86_device_context->GetPlace());
param.x_grad->template mutable_data<T>();
Tensor dx_matrix = dx->dims().size() > 2 ? framework::ReshapeToMatrix(
*dx, param.x_num_col_dims)
: *dx;
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix);
}
if (dy) {
// dy->yutable_data<T>(context.x86_device_context->GetPlace());
param.y_grad->template mutable_data<T>();
Tensor dy_matrix = dy->dims().size() > 2 ? framework::ReshapeToMatrix(
*dy, param.y_num_col_dims)
: *dy;
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix);
}
}
virtual ~MulGradCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(mul, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::MulCompute<float>, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
REGISTER_LITE_KERNEL(mul_grad, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::MulGradCompute<float>, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput(paddle::framework::GradVarName("Out"),
{LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput(paddle::framework::GradVarName("X"),
{LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput(paddle::framework::GradVarName("Y"),
{LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/operators/jit/kernels.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class SGDCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override {
auto &context = context_->As<X86Context>();
auto &sgd_param = *param_.get_mutable<operators::SGDParam>();
CHECK(context.x86_device_context);
// param.Out->template mutable_data<T>();
const auto *param = &sgd_param.Param->raw_tensor();
const auto *grad = &sgd_param.Grad->raw_tensor();
const auto *learning_rate = &sgd_param.LearningRate->raw_tensor();
auto *param_out = &sgd_param.ParamOut->raw_tensor();
auto sz = param_out->numel();
PADDLE_ENFORCE_EQ(param->numel(), sz);
PADDLE_ENFORCE_EQ(grad->numel(), sz);
paddle::operators::jit::sgd_attr_t attr(1, sz, 1, sz, 1);
const T *lr = learning_rate->data<T>();
const T *param_data = param->data<T>();
const T *grad_data = grad->data<T>();
int64_t rows_idx = 0;
T *out_data =
param_out->mutable_data<T>(context.x86_device_context->GetPlace());
auto sgd =
paddle::operators::jit::KernelFuncs<paddle::operators::jit::SgdTuple<T>,
platform::CPUPlace>::Cache()
.At(attr);
sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr);
}
virtual ~SGDCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// float
REGISTER_LITE_KERNEL(sgd, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::SGDCompute<float>, def)
.BindInput("Param", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("LearningRate", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Grad", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("ParamOut", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
...@@ -9,6 +9,9 @@ cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS}) ...@@ -9,6 +9,9 @@ cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS})
cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS}) cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS})
cc_library(activation_ops_lite SRCS activation_ops.cc DEPS ${op_DEPS}) cc_library(activation_ops_lite SRCS activation_ops.cc DEPS ${op_DEPS})
cc_library(elementwise_ops_lite SRCS elementwise_ops.cc DEPS ${op_DEPS}) cc_library(elementwise_ops_lite SRCS elementwise_ops.cc DEPS ${op_DEPS})
cc_library(mean_op_lite SRCS mean_op.cc DEPS ${op_DEPS})
cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS})
cc_library(sgd_op_lite SRCS sgd_op.cc DEPS ${op_DEPS})
cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite) cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite)
set(ops_lite set(ops_lite
...@@ -19,6 +22,9 @@ set(ops_lite ...@@ -19,6 +22,9 @@ set(ops_lite
feed_op_lite feed_op_lite
fetch_op_lite fetch_op_lite
io_copy_op_lite io_copy_op_lite
elementwise_ops_lite
mean_op_lite
fill_constant_op_lite
PARENT_SCOPE) PARENT_SCOPE)
lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite memory_lite) lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite memory_lite)
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
...@@ -36,16 +37,52 @@ class ActivationOp : public OpLite { ...@@ -36,16 +37,52 @@ class ActivationOp : public OpLite {
param_.X = GetVar<lite::Tensor>(scope, X_name); param_.X = GetVar<lite::Tensor>(scope, X_name);
param_.Out = GetMutableVar<Tensor>(scope, Out_name); param_.Out = GetMutableVar<Tensor>(scope, Out_name);
return true;
} }
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "activation_op"; }
private: private:
mutable ActivationParam param_; mutable ActivationParam param_;
}; };
class ActivationGradOp : public OpLite {
public:
explicit ActivationGradOp(const std::string& type) : OpLite(type) {}
bool CheckShape() const override {
CHECK_OR_FALSE(param_.X_grad);
CHECK_OR_FALSE(param_.Out_grad);
return true;
}
bool InferShape() const override {
param_.X_grad->Resize(param_.Out_grad->dims());
return true;
}
bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
auto Out_grad_name = opdesc.Input(framework::GradVarName("Out")).front();
auto X_grad_name = opdesc.Output(framework::GradVarName("X")).front();
param_.Out_grad = GetVar<lite::Tensor>(scope, Out_grad_name);
param_.X_grad = GetMutableVar<Tensor>(scope, X_grad_name);
return true;
}
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "activation_grad_op"; }
private:
mutable ActivationGradParam param_;
};
} // namespace operators } // namespace operators
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_OP(square, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(square, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp);
...@@ -31,7 +31,7 @@ class ElementwiseOp : public OpLite { ...@@ -31,7 +31,7 @@ class ElementwiseOp : public OpLite {
} }
bool InferShape() const override { bool InferShape() const override {
CHECK_OR_FALSE(param_.X->dims() == param_.Y->dims()); CHECK_OR_FALSE(param_.X->dims().size() >= param_.Y->dims().size());
param_.Out->Resize(param_.X->dims()); param_.Out->Resize(param_.X->dims());
return true; return true;
} }
...@@ -46,16 +46,64 @@ class ElementwiseOp : public OpLite { ...@@ -46,16 +46,64 @@ class ElementwiseOp : public OpLite {
param_.Y = GetVar<lite::Tensor>(scope, Y_name); param_.Y = GetVar<lite::Tensor>(scope, Y_name);
param_.Out = GetMutableVar<Tensor>(scope, Out_name); param_.Out = GetMutableVar<Tensor>(scope, Out_name);
param_.axis = boost::get<int>(opdesc.GetAttr("axis")); param_.axis = boost::get<int>(opdesc.GetAttr("axis"));
return true;
} }
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "elementwise_op"; }
private: private:
mutable operators::ElementwiseParam param_; mutable operators::ElementwiseParam param_;
}; };
class ElementwiseGradExplicitOp : public OpLite {
public:
explicit ElementwiseGradExplicitOp(const std::string& type) : OpLite(type) {}
bool CheckShape() const override {
CHECK_OR_FALSE(param_.Y);
CHECK_OR_FALSE(param_.X_grad);
CHECK_OR_FALSE(param_.Y_grad);
CHECK_OR_FALSE(param_.Out_grad);
return true;
}
bool InferShape() const override {
param_.X_grad->Resize(param_.Out_grad->dims());
param_.Y_grad->Resize(param_.Y->dims());
return true;
}
bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
CHECK_EQ(opdesc.Inputs().size(), 1UL);
auto Out_name = opdesc.Input(framework::GradVarName("Out")).front();
auto X_name = opdesc.Output(framework::GradVarName("X")).front();
auto Y_name = opdesc.Output(framework::GradVarName("Y")).front();
param_.Out_grad = GetVar<lite::Tensor>(scope, Out_name);
param_.X_grad = GetMutableVar<lite::Tensor>(scope, X_name);
param_.Y_grad = GetMutableVar<Tensor>(scope, Y_name);
param_.axis = boost::get<int>(opdesc.GetAttr("axis"));
return true;
}
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "elementwise_grad_explicit_op";
}
private:
mutable operators::ElementwiseGradParam param_;
};
} // namespace operators } // namespace operators
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_OP(elementwise_sub, paddle::lite::operators::ElementwiseOp); REGISTER_LITE_OP(elementwise_sub, paddle::lite::operators::ElementwiseOp);
REGISTER_LITE_OP(elementwise_sub_grad,
paddle::lite::operators::ElementwiseGradExplicitOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
class FillConstantOp : public OpLite {
public:
explicit FillConstantOp(const std::string& type) : OpLite(type) {}
bool CheckShape() const override {
CHECK_OR_FALSE(param_.Out);
return true;
}
bool InferShape() const override {
param_.Out->Resize(param_.shape);
return true;
}
bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
CHECK_EQ(opdesc.Inputs().size(), 2UL);
auto Out_name = opdesc.Output("Out").front();
param_.Out = GetMutableVar<Tensor>(scope, Out_name);
param_.dtype = boost::get<int>(opdesc.GetAttr("dtype"));
param_.shape = boost::get<std::vector<int64_t>>(opdesc.GetAttr("shape"));
param_.value = boost::get<float>(opdesc.GetAttr("value"));
param_.force_cpu = boost::get<bool>(opdesc.GetAttr("force_cpu"));
return true;
}
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "fill_constant"; }
private:
mutable operators::FillConstantParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(fill_constant, paddle::lite::operators::FillConstantOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
class MeanOp : public OpLite {
public:
explicit MeanOp(const std::string& type) : OpLite(type) {}
bool CheckShape() const override {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
return true;
}
bool InferShape() const override {
param_.Out->Resize(std::vector<int64_t>{1});
return true;
}
bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
CHECK_EQ(opdesc.Inputs().size(), 2UL);
auto X_name = opdesc.Input("X").front();
auto Out_name = opdesc.Output("Out").front();
param_.X = GetVar<lite::Tensor>(scope, X_name);
param_.Out = GetMutableVar<Tensor>(scope, Out_name);
return true;
}
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "mean"; }
private:
mutable operators::ElementwiseParam param_;
};
class MeanGradOp : public OpLite {
public:
explicit MeanGradOp(const std::string& type) : OpLite(type) {}
bool CheckShape() const override {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out_grad);
CHECK_OR_FALSE(param_.X_grad);
return true;
}
bool InferShape() const override {
param_.X_grad->Resize(param_.X->dims());
// param_.X_grad->set_lod(param_.X->lod());
return true;
}
bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
CHECK_EQ(opdesc.Inputs().size(), 3UL);
auto X_name = opdesc.Input("X").front();
auto Out_grad_name = opdesc.Input(framework::GradVarName("Out")).front();
auto X_grad_name = opdesc.Output(framework::GradVarName("X")).front();
param_.X = GetVar<lite::Tensor>(scope, X_name);
param_.Out_grad = GetVar<lite::Tensor>(scope, Out_grad_name);
param_.X_grad = GetMutableVar<Tensor>(scope, X_grad_name);
return true;
}
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "mean_grad"; }
private:
mutable operators::MeanGradParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(mean, paddle::lite::operators::MeanOp);
REGISTER_LITE_OP(mean_grad, paddle::lite::operators::MeanGradOp);
...@@ -28,9 +28,18 @@ bool MulOpLite::CheckShape() const { ...@@ -28,9 +28,18 @@ bool MulOpLite::CheckShape() const {
const auto x_dims = param_.x->dims(); const auto x_dims = param_.x->dims();
const auto y_dims = param_.y->dims(); const auto y_dims = param_.y->dims();
CHECK_EQ_OR_FALSE(y_dims.size(), 2UL);
CHECK_GT_OR_FALSE(x_dims.size(), static_cast<size_t>(param_.x_num_col_dims)); CHECK_GT_OR_FALSE(x_dims.size(), static_cast<size_t>(param_.x_num_col_dims));
CHECK_GT_OR_FALSE(y_dims.size(), static_cast<size_t>(param_.y_num_col_dims));
auto x_mat_dims =
framework::flatten_to_2d(x_dims.data(), param_.x_num_col_dims);
auto y_mat_dims =
framework::flatten_to_2d(y_dims.data(), param_.y_num_col_dims);
PADDLE_ENFORCE_EQ(x_mat_dims[1], y_mat_dims[0],
"First matrix's width must be equal with second matrix's "
"height. %s, %s",
x_mat_dims[1], y_mat_dims[0]);
return true; return true;
} }
...@@ -39,11 +48,16 @@ bool MulOpLite::InferShape() const { ...@@ -39,11 +48,16 @@ bool MulOpLite::InferShape() const {
const auto y_dims = param_.y->dims(); const auto y_dims = param_.y->dims();
// Set output dims // Set output dims
std::vector<int64_t> out_dims(param_.x_num_col_dims + 1, 0); std::vector<int64_t> out_dims(
param_.x_num_col_dims + y_dims.size() - param_.y_num_col_dims, 0);
for (int i = 0; i < param_.x_num_col_dims; ++i) { for (int i = 0; i < param_.x_num_col_dims; ++i) {
out_dims[i] = x_dims[i]; out_dims[i] = x_dims[i];
} }
out_dims.back() = y_dims[1];
for (auto i = static_cast<size_t>(param_.y_num_col_dims); i < y_dims.size();
++i) {
out_dims[i] = y_dims[i];
}
param_.output->Resize(lite::DDim(out_dims)); param_.output->Resize(lite::DDim(out_dims));
...@@ -52,6 +66,38 @@ bool MulOpLite::InferShape() const { ...@@ -52,6 +66,38 @@ bool MulOpLite::InferShape() const {
return true; return true;
} }
bool MulGradOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.y);
CHECK_OR_FALSE(param_.output_grad);
CHECK_OR_FALSE(param_.x_grad);
CHECK_OR_FALSE(param_.y_grad);
return true;
}
bool MulGradOpLite::InferShape() const {
param_.x_grad->Resize(param_.x->dims());
param_.y_grad->Resize(param_.y->dims());
return true;
}
bool MulGradOpLite::AttachImpl(const OpDesc &op_desc, lite::Scope *scope) {
auto X_name = op_desc.Input("X").front();
auto Y_name = op_desc.Input("Y").front();
auto Out_grad_name = op_desc.Output(framework::GradVarName("Out")).front();
auto X_grad_name = op_desc.Output(framework::GradVarName("X")).front();
auto Y_grad_name = op_desc.Output(framework::GradVarName("Y")).front();
param_.x = GetVar<lite::Tensor>(scope, X_name);
param_.y = GetVar<lite::Tensor>(scope, Y_name);
param_.output_grad = GetVar<lite::Tensor>(scope, Out_grad_name);
param_.x_grad = GetMutableVar<lite::Tensor>(scope, X_grad_name);
param_.y_grad = GetMutableVar<lite::Tensor>(scope, Y_grad_name);
return true;
}
} // namespace operators } // namespace operators
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
......
...@@ -61,6 +61,26 @@ class MulOpLite : public OpLite { ...@@ -61,6 +61,26 @@ class MulOpLite : public OpLite {
mutable MulParam param_; mutable MulParam param_;
}; };
class MulGradOpLite : public OpLite {
public:
MulGradOpLite() {}
explicit MulGradOpLite(const std::string &type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override;
std::string DebugString() const override { return "mul_grad"; }
private:
mutable MulGradParam param_;
};
} // namespace operators } // namespace operators
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -72,6 +72,17 @@ struct MulParam { ...@@ -72,6 +72,17 @@ struct MulParam {
int y_num_col_dims{1}; int y_num_col_dims{1};
}; };
struct MulGradParam {
const lite::Tensor* x{};
const lite::Tensor* y{};
const lite::Tensor* output_grad{};
lite::Tensor* x_grad{};
lite::Tensor* y_grad{};
int x_num_col_dims{1};
int y_num_col_dims{1};
};
// For Scale Op // For Scale Op
struct ScaleParam { struct ScaleParam {
lite::Tensor* x{}; lite::Tensor* x{};
...@@ -91,9 +102,10 @@ struct ElementwiseParam { ...@@ -91,9 +102,10 @@ struct ElementwiseParam {
}; };
struct ElementwiseGradParam { struct ElementwiseGradParam {
const lite::Tensor* X_grad{}; const lite::Tensor* Y{};
const lite::Tensor* Y_grad{}; const lite::Tensor* Out_grad{};
lite::Tensor* Out_grad{}; lite::Tensor* X_grad{};
lite::Tensor* Y_grad{};
int axis{-1}; // for broadcasting. int axis{-1}; // for broadcasting.
}; };
...@@ -111,6 +123,39 @@ struct ActivationGradParam { ...@@ -111,6 +123,39 @@ struct ActivationGradParam {
const lite::Tensor* Out_grad{}; const lite::Tensor* Out_grad{};
}; };
/// ----------------------- mean operators ----------------------
struct MeanParam {
const lite::Tensor* X{};
lite::Tensor* Out{};
};
struct MeanGradParam {
const lite::Tensor* X{};
const lite::Tensor* Out_grad{};
// for backward
lite::Tensor* X_grad{};
};
/// ----------------------- fill_constant operators ----------------------
struct FillConstantParam {
int dtype{framework::proto::VarType::FP32};
std::vector<int64_t> shape{};
float value{0.0f};
// useless for x86, keep it for compatibility
bool force_cpu{false};
lite::Tensor* Out{};
};
/// ----------------------- sgd operators ----------------------
struct SGDParam {
int dtype{framework::proto::VarType::FP32};
const lite::Tensor* Param{};
const lite::Tensor* LearningRate{};
const lite::Tensor* Grad{};
lite::Tensor* ParamOut{};
};
} // namespace operators } // namespace operators
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "/paddle/paddle/fluid/lite/operators/sgd_op.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SGDOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.Param);
CHECK_OR_FALSE(param_.LearningRate);
CHECK_OR_FALSE(param_.Grad);
CHECK_OR_FALSE(param_.ParamOut);
return true;
}
bool SGDOpLite::InferShape() const {
auto lr_dims = param_.LearningRate->dims().data();
CHECK_EQ_OR_FALSE(framework::product(lr_dims), 1);
param_.ParamOut->Resize(param_.Param->dims());
return true;
}
bool SGDOpLite::AttachImpl(const OpDesc& opdesc, lite::Scope* scope) {
CHECK_EQ(opdesc.Inputs().size(), 3UL);
auto Param_name = opdesc.Input("Param").front();
auto LearningRate_name = opdesc.Input("LearningRate").front();
auto Grad_name = opdesc.Input("Grad").front();
auto ParamOut_name = opdesc.Output("ParamOut").front();
param_.Param = GetVar<lite::Tensor>(scope, Param_name);
param_.LearningRate = GetVar<lite::Tensor>(scope, LearningRate_name);
param_.Grad = GetVar<Tensor>(scope, Grad_name);
param_.ParamOut = GetMutableVar<Tensor>(scope, ParamOut_name);
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(sgd, paddle::lite::operators::SGDOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class SGDOpLite : public OpLite {
public:
SGDOpLite() {}
explicit SGDOpLite(const std::string &type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override;
std::string DebugString() const override { return "sgd"; }
private:
mutable SGDParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册