Commit 6a75e5f5 authored by Jason Rhinelander's avatar Jason Rhinelander

New right/tolerance argument syntax; remove domain_t deduction

This changes the `right` and `tolerance` argument into more complex
objects that allow both a literal, or special objects returned by
`search_right()`, `absolute_tolerance(.001)`, or
`relative_tolerance(.1)`.

This allows the much more descriptive syntax:

    constrained_maximum_search(f, 0, search_right(), relative_tolerance(.001))

This also reintroduces the ability to specify an absolute tolerance
level, which was there once (as a separate argument) but got lost along
the way.

This commit also changes the `domain_t` argument to be no longer
deduced.  Previously it was deduced from `left`, which would have made
the above example rather broken: `domain_t` would have been `int` (from
the 0) rather than `double` (which would require 0.0).  Overriding
`domain_t` now requires an explicitly template parameter:

    constrained_maximum_search<float>(f, 0, search_right(), relative_tolerance(.001))
parent 512cdae1
......@@ -265,6 +265,53 @@ template <typename domain_t = double, typename value_t = double> struct search_r
arg(std::move(a)), value(std::move(v)), inside(in), iterations(it) {}
};
template <typename T> using non_deduced = std::common_type_t<T>;
/** Special class to pass a tolerance into one of the following search functions. This is usually
* constructed via a call to either `absolute_tolerance`, to `relative_tolerance`, or by implicit
* conversion from double (which is equivalent to relative_tolerance). */
template <typename AbsTol_t> struct search_tolerance {
const bool is_relative;
const double relative;
const AbsTol_t absolute;
/** Implicit conversion from a double gives relative tolerance. Negative values are replaced
* with 0. */
search_tolerance(double rel) : is_relative{true}, relative{rel > 0. ? rel : 0.}, absolute{0} {}
/** Constructs an absolute tolerance; this is usually invoked via the absolute_tolerance()
* function.
*/
search_tolerance(AbsTol_t abs, bool /* unused */) : is_relative{false}, relative{0}, absolute{std::move(abs)} {}
search_tolerance(const search_tolerance &) = default;
search_tolerance &operator=(const search_tolerance &) = default;
search_tolerance(search_tolerance &&) = default;
search_tolerance &operator=(search_tolerance &&) = default;
/** Implicit conversion to a tolerance with a different domain type; this casts the absolute
* value from the foreign to the local type; it is only allowed when the absolute type is
* convertible.
*/
template <typename foreign_t, std::enable_if_t<std::is_convertible<foreign_t, AbsTol_t>::value, int> = 0>
search_tolerance(const search_tolerance<foreign_t> &tol)
: is_relative{tol.is_relative}, relative{tol.relative}, absolute{tol.absolute} {}
/** Implicit conversion to a type with a different domain type with non-convertible domain
* types; this is only allowed if the copied-from value is a relative tolerance instance (and
* throws a `domain_error` if not).
*/
template <typename foreign_t, std::enable_if_t<!std::is_convertible<foreign_t, AbsTol_t>::value, int> = 0>
search_tolerance(const search_tolerance<foreign_t> &tol)
: is_relative{tol.is_relative}, relative{tol.relative}, absolute{0} {
if (!is_relative) throw std::domain_error("Cannot cast absolute search_tolerance types");
}
};
/// Constructs a tolerance object that specifies absolute tolerance.
template <typename domain_t> search_tolerance<domain_t> absolute_tolerance(domain_t tol) {
return search_tolerance<domain_t>(tol, true);
}
/// Constructs a tolerance object that specifies relative tolerance.
inline search_tolerance<bool> relative_tolerance(double tol) { return tol; }
/// The constant phi. Callers can specialize this template if using custom types with more
/// precision than a long double value.
template <typename RealType> constexpr RealType phi = 1.61803398874989484820458683436563811L;
......@@ -312,11 +359,9 @@ template <typename RealType> constexpr RealType golden_section_left = RealType(1
template <typename domain_t = double, typename Func, typename value_t = decltype(std::declval<Func>()(std::declval<domain_t>()))>
search_result<domain_t, value_t> single_peak_search(
Func f,
std::common_type_t<domain_t> left,
std::common_type_t<domain_t> right,
std::common_type_t<domain_t> tol_rel = 1e-10) {
if (tol_rel < 0) tol_rel = 0;
non_deduced<domain_t> left,
non_deduced<domain_t> right,
search_tolerance<non_deduced<domain_t>> tolerance = 1e-10) {
constexpr domain_t midpoint_right = phi<domain_t> - domain_t(1);
constexpr domain_t midpoint_left = domain_t(1) - midpoint_right;
......@@ -332,7 +377,8 @@ search_result<domain_t, value_t> single_peak_search(
int iterations = 1; // Count the above mid calcs as an iteration
using std::abs; // Don't use std::abs directly (to allow ADL on abs)
using std::max;
while (span > tol_rel * max(abs(left), abs(right))) {
bool done;
do {
iterations++;
if (fml >= fmr) {
// midleft is the higher point, so we can exclude everything right of midright.
......@@ -365,7 +411,11 @@ search_result<domain_t, value_t> single_peak_search(
swap(midleft, midright);
swap(fml, fmr);
}
}
done = tolerance.is_relative
? span <= tolerance.relative * max(abs(left), abs(right))
: span <= tolerance.absolute;
} while (!done);
// Prefer the end-points for ties (the max might legitimately be an end-point), and prefer left
// over right (for no particularly good reason).
......@@ -379,13 +429,55 @@ search_result<domain_t, value_t> single_peak_search(
return {std::move(midright), std::move(fmr), true, iterations};
}
template <typename T> using non_deduced = T;
/// Divides by 2, but with specializations for floating point types that do so by multiplying by
/// 0.5 instead.
template <typename T, std::enable_if_t< std::is_floating_point<T>::value, int> = 0> T half(T val) { return val * T(0.5); }
template <typename T, std::enable_if_t<!std::is_floating_point<T>::value, int> = 0> T half(T val) { return val / T(2); }
/** Class that specifies a RHS limit when implicitly converted from a domain value, or that
* specifies that the RHS limit should be found automatically when default constructed.
*
* This is not typically invoked directly: instead either specify a value for the `right` argument,
* or specify `search_right()`.
*/
template <typename domain_t>
struct search_right_val {
/// True if right-hand-side value should be found automatically
bool search = true;
/// The right value (if `search` is false).
domain_t right;
/// Default constructor: specifies a `right` argument that should be found automatically.
search_right_val() = default;
/// Implicit conversion from a `domain_t` value: uses the specified `right` value directly
/// (without searching). If the type supports NaN and the given value is NaN, `search` mode is
/// enabled.
search_right_val(domain_t right) : search{false}, right{right} {
using std::isnan;
if (std::numeric_limits<domain_t>::has_quiet_NaN() && isnan(right))
search = true;
}
/** Implicit conversion to a `search_right_val` with a different domain type; this casts the
* right value from the foreign to the local type; it is only allowed when the value type is
* convertible. */
template <typename foreign_t, std::enable_if_t<std::is_convertible<foreign_t, domain_t>::value, int> = 0>
search_right_val(const search_right_val<foreign_t> &srv)
: search{srv.search}, right{srv.right} {}
/** Implicit conversion to a type with a different domain type with non-convertible domain
* types; this is only allowed if the copied-from value is a `search = true` instance (and
* throws a `domain_error` if not).
*/
template <typename foreign_t, std::enable_if_t<!std::is_convertible<foreign_t, domain_t>::value, int> = 0>
search_right_val(const search_right_val<foreign_t> &srv)
: search{srv.search}, right{0} {
if (!search) throw std::domain_error("Cannot cast absolute search_right_val types");
}
};
/** Constructs a `search_right_val` that searches for the right-hand side. */
inline search_right_val<bool> search_right() { return {}; }
/** Performs a binary search to find the maximum function value that satisfies a constraint given
* a pair of values that satisfy and do not satisfy the constraint.
*
......@@ -400,10 +492,18 @@ template <typename T, std::enable_if_t<!std::is_floating_point<T>::value, int> =
* `f(left)` (if not, the algorithms fails immediately).
*
* \param right the right edge of the domain to consider at which the constraint should not be
* satisfied. If specified as NaN the algorithm first starts at x = max{-left, 2*left}, doubling x
* until a constraint violation is encountered, at which point this `x` becomes `right`. (This
* attempts to determine `right` are not counted in iterations in the returned object). If `left`
* is 0, this starts looking at `right` at 1.0 (then doubling, etc.).
* satisfied. Can be specified as a regular value, or as a special `search_right()` value: if the
* latter, the algorithm first starts at x = max{-left, 2*left}, doubling x until a constraint
* violation is encountered, at which point this `x` becomes `right`. (This attempts to determine
* `right` are not counted in iterations in the returned object). If `left` is 0, this starts
* looking at `right` at 1.0 (then doubling, etc.).
*
* \param tolerance the domain value tolerance at which to stop searching: the algoritm proceeds
* until the difference between the left and right edge has been narrowed to this value or less.
* This is typically specified as a double value for relative tolerance (relative to the larger
* absolute value of right or left), or a call to `absolute_tolerance(...)` for absolute tolerance.
* The default, if unspecified, is relative tolerance of 1e-10. It is safe to pass a value of 0
* here: the algorithm will then run to the maximum precision of the domain type.
*
* \return a `eris::search_result` struct. If the initial `f(left)` is not satisfied, this
* immediately returns (with `value` set to false, `inside` set to false, and `arg` set to left).
......@@ -413,9 +513,12 @@ template <typename T, std::enable_if_t<!std::is_floating_point<T>::value, int> =
* `.inside` set to true. (Note that a maximum at exactly `left` is considered "inside" `[left,
* right)`).
*/
template <typename domain_t, typename Func>
template <typename domain_t = double, typename Func>
search_result<domain_t, bool> constrained_maximum_search(
Func f, domain_t left, non_deduced<domain_t> right, non_deduced<domain_t> tol_rel = domain_t(1e-10)) {
Func f,
non_deduced<domain_t> left,
search_right_val<non_deduced<domain_t>> &&s_right,
search_tolerance<non_deduced<domain_t>> tolerance = 1e-10) {
static_assert(std::is_same<bool, decltype(f(left))>::value,
"constrained_maximum_search: given function must return a `bool` value");
......@@ -428,18 +531,26 @@ search_result<domain_t, bool> constrained_maximum_search(
using std::isfinite;
using std::isnan;
if (isnan(right)) {
domain_t right;
if (s_right.search) {
domain_t x = left < domain_t(0) ? -left : left > domain_t(0) ? domain_t(2)*left : domain_t(1);
while (isfinite(x) && f(x)) x *= domain_t(2);
while (isfinite(x) && f(x)) {
left = x;
fl = true;
x *= domain_t(2);
}
right = x;
}
else
right = s_right.right;
bool fr = f(right);
if (fr || !isfinite(right)) return {std::move(right), fr, false, 0};
domain_t span = right - left;
int iterations = 0;
while (span > tol_rel * max(abs(left), abs(right))) {
bool done;
do {
iterations++;
domain_t mid = left + half(span);
if (mid == left || mid == right) break; // Numerical precision limit
......@@ -454,7 +565,11 @@ search_result<domain_t, bool> constrained_maximum_search(
fr = fm;
}
span = right - left;
}
done = tolerance.is_relative
? span <= tolerance.relative * max(abs(left), abs(right))
: span <= tolerance.absolute;
} while (!done);
return {std::move(left), true, true, iterations};
}
......@@ -466,21 +581,30 @@ search_result<domain_t, bool> constrained_maximum_search(
/// `constrained_maximum_search`, `right` can be specified as NaN, in which case the same search for
/// an initial `right` value will be done as in `constrained_maximum_search`, but looking for an
/// initial `right` value that doesn't satisfy the constraint.
template <typename domain_t, typename Func>
template <typename domain_t = double, typename Func>
search_result<domain_t, bool> constrained_minimum_search(
Func f, domain_t left, domain_t right, std::common_type_t<domain_t> tol_rel = domain_t(1e-10)) {
Func f,
non_deduced<domain_t> left,
search_right_val<non_deduced<domain_t>> &&s_right,
search_tolerance<non_deduced<domain_t>> tolerance = 1e-10) {
static_assert(std::is_same<bool, decltype(f(left))>::value,
"constrained_minimum_search: given function must return a `bool` value");
if (isnan(right)) {
domain_t right;
if (s_right.search) {
domain_t x = left < domain_t(0) ? -left : left > domain_t(0) ? domain_t(2)*left : domain_t(1);
while (isfinite(x) && !f(x)) x *= domain_t(2);
while (isfinite(x) && !f(x)) {
left = x;
x *= domain_t(2);
}
right = x;
}
else
right = s_right.right;
auto ret = constrained_maximum_search(
[f = std::move(f)](domain_t a) { return f(-a); },
-right, -left, tol_rel);
-right, -left, tolerance);
ret.arg = -ret.arg;
return ret;
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment