Better handling of nonlinear constraints in nlopt wrapper

This commit is contained in:
tamasmeszaros 2022-11-29 18:25:39 +01:00
parent cdac790163
commit 056e740027

View File

@ -74,18 +74,28 @@ Fn for_each_argument(Fn &&fn, Args&&...args)
return fn; return fn;
} }
template<class Fn, class...Args> // Call fn on each element of the input tuple tup.
Fn for_each_in_tuple(Fn fn, const std::tuple<Args...> &tup) template<class Fn, class Tup>
Fn for_each_in_tuple(Fn fn, Tup &&tup)
{ {
auto arg = std::tuple_cat(std::make_tuple(fn), tup); auto mpfn = [&fn](auto&...pack) {
auto mpfn = [](auto fn, auto...pack) { for_each_argument(fn, pack...);
return for_each_argument(fn, pack...);
}; };
std::apply(mpfn, arg);
std::apply(mpfn, tup);
return fn; return fn;
} }
// Wrap each element of the tuple tup into a wrapper class W and return
// a new tuple with each element being of type W<T_i> where T_i is the type of
// i-th element of tup.
template<template<class> class W, class...Args>
auto wrap_tup(const std::tuple<Args...> &tup)
{
return std::tuple<W<Args>...>(tup);
}
// Optimizers based on NLopt. // Optimizers based on NLopt.
template<nlopt_algorithm alg> class NLoptOpt<NLoptAlg<alg>> { template<nlopt_algorithm alg> class NLoptOpt<NLoptAlg<alg>> {
protected: protected:
@ -94,32 +104,38 @@ protected:
static constexpr double ConstraintEps = 1e-6; static constexpr double ConstraintEps = 1e-6;
template<class Fn> using TOptData = template<class Fn> struct OptData {
std::tuple<std::remove_reference_t<Fn>*, NLoptOpt*, nlopt_opt>; Fn fn;
NLoptOpt *self = nullptr;
nlopt_opt opt_raw = nullptr;
OptData(const Fn &f): fn{f} {}
OptData(const Fn &f, NLoptOpt *s, nlopt_opt nlopt_raw)
: fn{f}, self{s}, opt_raw{nlopt_raw} {}
};
template<class Fn, size_t N> template<class Fn, size_t N>
static double optfunc(unsigned n, const double *params, static double optfunc(unsigned n, const double *params,
double *gradient, double *gradient, void *data)
void *data)
{ {
assert(n == N); assert(n == N);
auto tdata = static_cast<TOptData<Fn>*>(data); auto tdata = static_cast<OptData<Fn>*>(data);
if (std::get<1>(*tdata)->m_stopcr.stop_condition()) if (tdata->self->m_stopcr.stop_condition())
nlopt_force_stop(std::get<2>(*tdata)); nlopt_force_stop(tdata->opt_raw);
auto fnptr = std::get<0>(*tdata);
auto funval = to_arr<N>(params); auto funval = to_arr<N>(params);
double scoreval = 0.; double scoreval = 0.;
using RetT = decltype((*fnptr)(funval)); using RetT = decltype(tdata->fn(funval));
if constexpr (std::is_convertible_v<RetT, ScoreGradient<N>>) { if constexpr (std::is_convertible_v<RetT, ScoreGradient<N>>) {
ScoreGradient<N> score = (*fnptr)(funval); ScoreGradient<N> score = tdata->fn(funval);
for (size_t i = 0; i < n; ++i) gradient[i] = (*score.gradient)[i]; for (size_t i = 0; i < n; ++i) gradient[i] = (*score.gradient)[i];
scoreval = score.score; scoreval = score.score;
} else { } else {
scoreval = (*fnptr)(funval); scoreval = tdata->fn(funval);
} }
return scoreval; return scoreval;
@ -127,17 +143,14 @@ protected:
template<class Fn, size_t N> template<class Fn, size_t N>
static double constrain_func(unsigned n, const double *params, static double constrain_func(unsigned n, const double *params,
double *gradient, double *gradient, void *data)
void *data)
{ {
assert(n == N); assert(n == N);
auto tdata = static_cast<TOptData<Fn>*>(data); auto tdata = static_cast<OptData<Fn>*>(data);
auto &fnptr = std::get<0>(*tdata);
auto funval = to_arr<N>(params); auto funval = to_arr<N>(params);
return (*fnptr)(funval); return tdata->fn(funval);
} }
template<size_t N> template<size_t N>
@ -173,22 +186,27 @@ protected:
{ {
Result<N> r; Result<N> r;
TOptData<Fn> data = std::make_tuple(&fn, this, nl.ptr); OptData<Fn> data {fn, this, nl.ptr};
auto do_for_each_eq = [this, &nl](auto &&arg) { auto do_for_each_eq = [this, &nl](auto &arg) {
auto data = std::make_tuple(&arg, this, nl.ptr); arg.self = this;
using F = std::remove_cv_t<decltype(arg)>; arg.opt_raw = nl.ptr;
nlopt_add_equality_constraint (nl.ptr, constrain_func<F, N>, &data, ConstraintEps); using F = decltype(arg.fn);
nlopt_add_equality_constraint (nl.ptr, constrain_func<F, N>, &arg, ConstraintEps);
}; };
auto do_for_each_ineq = [this, &nl](auto &&arg) { auto do_for_each_ineq = [this, &nl](auto &arg) {
auto data = std::make_tuple(&arg, this, nl.ptr); arg.self = this;
using F = std::remove_cv_t<decltype(arg)>; arg.opt_raw = nl.ptr;
nlopt_add_inequality_constraint (nl.ptr, constrain_func<F, N>, &data, ConstraintEps); using F = decltype(arg.fn);
nlopt_add_inequality_constraint (nl.ptr, constrain_func<F, N>, &arg, ConstraintEps);
}; };
for_each_in_tuple(do_for_each_eq, equalities); auto eq_data = wrap_tup<OptData>(equalities);
for_each_in_tuple(do_for_each_ineq, inequalities); for_each_in_tuple(do_for_each_eq, eq_data);
auto ineq_data = wrap_tup<OptData>(inequalities);
for_each_in_tuple(do_for_each_ineq, ineq_data);
switch(m_dir) { switch(m_dir) {
case OptDir::MIN: case OptDir::MIN:
@ -260,8 +278,19 @@ public:
const StopCriteria &get_loc_criteria() const noexcept { return m_loc_stopcr; } const StopCriteria &get_loc_criteria() const noexcept { return m_loc_stopcr; }
}; };
template<class Alg> struct AlgFeatures_ {
static constexpr bool SupportsInequalities = false;
static constexpr bool SupportsEqualities = false;
};
} // namespace detail; } // namespace detail;
template<class Alg> constexpr bool SupportsEqualities =
detail::AlgFeatures_<remove_cvref_t<Alg>>::SupportsEqualities;
template<class Alg> constexpr bool SupportsInequalities =
detail::AlgFeatures_<remove_cvref_t<Alg>>::SupportsInequalities;
// Optimizers based on NLopt. // Optimizers based on NLopt.
template<class M> class Optimizer<M, detail::NLoptOnly<M>> { template<class M> class Optimizer<M, detail::NLoptOnly<M>> {
detail::NLoptOpt<M> m_opt; detail::NLoptOpt<M> m_opt;
@ -278,6 +307,14 @@ public:
const std::tuple<EqFns...> &eq_constraints = {}, const std::tuple<EqFns...> &eq_constraints = {},
const std::tuple<IneqFns...> &ineq_constraint = {}) const std::tuple<IneqFns...> &ineq_constraint = {})
{ {
static_assert(std::tuple_size_v<std::tuple<EqFns...>> == 0
|| SupportsEqualities<M>,
"Equality constraints are not supported.");
static_assert(std::tuple_size_v<std::tuple<IneqFns...>> == 0
|| SupportsInequalities<M>,
"Inequality constraints are not supported.");
return m_opt.optimize(std::forward<Func>(func), initvals, bounds, return m_opt.optimize(std::forward<Func>(func), initvals, bounds,
eq_constraints, eq_constraints,
ineq_constraint); ineq_constraint);
@ -299,13 +336,41 @@ public:
}; };
// Predefinded NLopt algorithms // Predefinded NLopt algorithms
using AlgNLoptGenetic = detail::NLoptAlgComb<NLOPT_GN_ESCH>; using AlgNLoptGenetic = detail::NLoptAlgComb<NLOPT_GN_ESCH>;
using AlgNLoptSubplex = detail::NLoptAlg<NLOPT_LN_SBPLX>; using AlgNLoptSubplex = detail::NLoptAlg<NLOPT_LN_SBPLX>;
using AlgNLoptSimplex = detail::NLoptAlg<NLOPT_LN_NELDERMEAD>; using AlgNLoptSimplex = detail::NLoptAlg<NLOPT_LN_NELDERMEAD>;
using AlgNLoptCobyla = detail::NLoptAlg<NLOPT_LN_COBYLA>; using AlgNLoptCobyla = detail::NLoptAlg<NLOPT_LN_COBYLA>;
using AlgNLoptDIRECT = detail::NLoptAlg<NLOPT_GN_DIRECT>; using AlgNLoptDIRECT = detail::NLoptAlg<NLOPT_GN_DIRECT>;
using AlgNLoptISRES = detail::NLoptAlg<NLOPT_GN_ISRES>; using AlgNLoptORIG_DIRECT = detail::NLoptAlg<NLOPT_GN_ORIG_DIRECT>;
using AlgNLoptMLSL = detail::NLoptAlgComb<NLOPT_GN_MLSL, NLOPT_LN_SBPLX>; using AlgNLoptISRES = detail::NLoptAlg<NLOPT_GN_ISRES>;
using AlgNLoptAGS = detail::NLoptAlg<NLOPT_GN_AGS>;
using AlgNLoptMLSL = detail::NLoptAlgComb<NLOPT_GN_MLSL, NLOPT_LN_SBPLX>;
using AlgNLoptMLSL_Cobyla = detail::NLoptAlgComb<NLOPT_GN_MLSL, NLOPT_LN_COBYLA>;
namespace detail {
template<> struct AlgFeatures_<AlgNLoptCobyla> {
static constexpr bool SupportsInequalities = true;
static constexpr bool SupportsEqualities = true;
};
template<> struct AlgFeatures_<AlgNLoptISRES> {
static constexpr bool SupportsInequalities = true;
static constexpr bool SupportsEqualities = false;
};
template<> struct AlgFeatures_<AlgNLoptORIG_DIRECT> {
static constexpr bool SupportsInequalities = true;
static constexpr bool SupportsEqualities = false;
};
template<> struct AlgFeatures_<AlgNLoptAGS> {
static constexpr bool SupportsInequalities = true;
static constexpr bool SupportsEqualities = true;
};
} // namespace detail
}} // namespace Slic3r::opt }} // namespace Slic3r::opt