1#ifndef XPED_OPTIM_OPTS_HPP
2#define XPED_OPTIM_OPTS_HPP
4#include <boost/describe.hpp>
8#include "yas/serialize.hpp"
9#include "yas/std_types.hpp"
// Optimization algorithms that can be selected through the `alg` option.
// Declared with Boost.Describe's BOOST_DEFINE_ENUM_CLASS, which defines the enum class and
// registers its enumerators for reflection.
20BOOST_DEFINE_ENUM_CLASS(Algorithm, L_BFGS, CONJUGATE_GRADIENT, NELDER_MEAD)
// Line-search variants that can be selected through the `ls` option (same Boost.Describe mechanism).
22BOOST_DEFINE_ENUM_CLASS(Linesearch, WOLFE, ARMIJO)
// Option fields with their defaults. Each field can be overridden from a TOML table by
// optim_from_toml() and is listed by info().
// NOTE(review): the enclosing struct header and some members used elsewhere in this file
// (resume, seed, id, qn_scale) are not visible in this excerpt.

// Optimization algorithm; defaults to L-BFGS.
26 Algorithm alg = Algorithm::L_BFGS;
// Line-search variant; defaults to Wolfe.
28 Linesearch ls = Linesearch::WOLFE;
// Reported by info() as "bfgs scaling"; enabled by default.
// NOTE(review): the exact effect on the optimizer is not visible here.
31 bool bfgs_scaling =
true;
// Gradient, step and cost tolerances (reported as such by info()).
// presumably convergence thresholds; how the optimizer applies them is not shown here.
33 double grad_tol = 1.e-6;
34 double step_tol = 1.e-12;
35 double cost_tol = 1.e-12;
// Maximum and minimum number of steps (reported as such by info()).
37 std::size_t max_steps = 200;
38 std::size_t min_steps = 10;
// Format of the data named by `load`; LoadFormat is declared outside this excerpt.
42 LoadFormat load_format = LoadFormat::NATIVE;
// Source to load from. Empty by default; info() prints the load-related rows only when this
// is non-empty and prints the seed otherwise.
43 std::string load =
"";
// Save period, default 0.
// NOTE(review): units and the meaning of 0 are not visible here; confirm against the caller.
50 std::size_t save_period = 0;
// Log format, default ".log".
52 std::string log_format =
".log";
// Working directory; defaults to the process's current path at the time this initializer runs.
54 std::filesystem::path working_directory = std::filesystem::current_path();
// Logging directory, default "logs".
// presumably resolved relative to working_directory; verify against the code that uses it.
55 std::filesystem::path logging_directory =
"logs";
// Observables directory, default "obs" (same caveat as logging_directory).
57 std::filesystem::path obs_directory =
"obs";
// Reported by info() as "display obs to terminal"; enabled by default.
58 bool display_obs =
true;
// Verbosity level; Verbosity is declared outside this excerpt. Defaults to PER_ITERATION.
60 Verbosity verbosity = Verbosity::PER_ITERATION;
// serialize(): passes the options to the archive `ar` as a YAS named-value object called
// "OptimOpts", one (key, member) pair per option.
// Template parameter Ar is the archive type.
// NOTE(review): the function signature and some of the pairs are missing from this excerpt,
// so the list below may not be the complete set of serialized fields.
// NOTE(review): the key "load format" contains a space while every other key uses underscores
// (optim_from_toml reads the TOML key "load_format"). Confirm whether this is intentional before
// renaming it, since presumably the key is part of the stored representation.
62 template <
typename Ar>
65 ar& YAS_OBJECT_NVP(
"OptimOpts",
68 (
"bfgs_scaling", bfgs_scaling),
69 (
"grad_tol", grad_tol),
70 (
"step_tol", step_tol),
71 (
"cost_tol", cost_tol),
72 (
"max_steps", max_steps),
73 (
"min_steps", min_steps),
75 (
"load format", load_format),
76 (
"qn_scale", qn_scale),
79 (
"save_period", save_period),
80 (
"log_format", log_format),
81 (
"working_directory", working_directory),
82 (
"logging_directory", logging_directory),
83 (
"obs_directory", obs_directory),
84 (
"verbosity", verbosity),
85 (
"display_obs", display_obs));
91 fmt::format_to(std::back_inserter(res),
"{}:\n", fmt::styled(
"Optimization options", fmt::emphasis::bold));
92 fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• Algorithm:", fmt::streamed(alg));
93 fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• Linesearch:", fmt::streamed(ls));
94 fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• maximum steps:", max_steps);
95 fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• minimum steps:", min_steps);
96 fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• gradient tolerance:", grad_tol);
97 fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• cost tolerance:", cost_tol);
98 fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• step tolerance:", step_tol);
99 fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• bfgs scaling:", bfgs_scaling);
100 fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• resume:", resume);
101 fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• log format:", log_format);
102 fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• working directory:", working_directory.string());
103 fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• logging directory:", logging_directory.string());
104 fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• obs directory:", obs_directory.string());
105 if(load.size() > 0) { fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• load from:", load); }
106 if(load.size() > 0) { fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• scale loaded qn by:", qn_scale); }
107 if(load.size() > 0) { fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• load format:", fmt::streamed(load_format)); }
108 if(load.size() == 0) { fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• seed:", seed); }
109 fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• id:",
id);
110 fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• save period:", save_period);
111 fmt::format_to(std::back_inserter(res),
" {:<30} {}\n",
"• verbosity:", fmt::streamed(verbosity));
112 fmt::format_to(std::back_inserter(res),
" {:<30} {}",
"• display obs to terminal:", display_obs);
// Body of optim_from_toml(const toml::value& t): fills `res` from the TOML table `t`.
// Every key is optional; when a key is absent the value already held by `res` is kept.
// NOTE(review): the function signature, the declaration of `res`, the return statement and
// some closing braces are missing from this excerpt.

// Enum-valued options are converted with util::enum_from_toml.
120 if(t.contains(
"algorithm")) { res.
alg = util::enum_from_toml<Algorithm>(t.at(
"algorithm")); }
121 if(t.contains(
"linesearch")) { res.ls = util::enum_from_toml<Linesearch>(t.at(
"linesearch")); }
// Tolerances are read as TOML floating-point values.
122 res.grad_tol = t.contains(
"grad_tol") ? t.at(
"grad_tol").as_floating() : res.grad_tol;
123 res.step_tol = t.contains(
"step_tol") ? t.at(
"step_tol").as_floating() : res.step_tol;
124 res.cost_tol = t.contains(
"cost_tol") ? t.at(
"cost_tol").as_floating() : res.cost_tol;
// Step limits are read as TOML integers and stored in std::size_t members.
// NOTE(review): toml11's as_integer() presumably yields a signed 64-bit value, so a negative
// entry would convert to a very large unsigned number; confirm whether validation is needed.
125 res.max_steps = t.contains(
"max_steps") ? t.at(
"max_steps").as_integer() : res.max_steps;
126 res.min_steps = t.contains(
"min_steps") ? t.at(
"min_steps").as_integer() : res.min_steps;
127 res.bfgs_scaling = t.contains(
"bfgs_scaling") ? t.at(
"bfgs_scaling").as_boolean() : res.bfgs_scaling;
128 res.resume = t.contains(
"resume") ? t.at(
"resume").as_boolean() : res.resume;
// Load source, the scale applied to loaded quantum numbers (per the info() label) and the load format.
129 res.load = t.contains(
"load") ?
static_cast<std::string
>(t.at(
"load").as_string()) : res.load;
// NOTE(review): qn_scale is read with as_integer(); its declared type is not visible here.
130 res.qn_scale = t.contains(
"qn_scale") ? (t.at(
"qn_scale").as_integer()) : res.qn_scale;
131 if(t.contains(
"load_format")) { res.load_format = util::enum_from_toml<LoadFormat>(t.at(
"load_format")); }
// Same signed-integer-to-std::size_t conversion as for the step limits above.
132 res.save_period = t.contains(
"save_period") ? t.at(
"save_period").as_integer() : res.save_period;
133 res.log_format = t.contains(
"log_format") ?
static_cast<std::string
>(t.at(
"log_format").as_string()) : res.log_format;
// Working directory: a relative path is resolved against the current path; the second
// assignment stores the path unchanged.
// NOTE(review): the line separating the two assignments (presumably an `else`) is not visible.
134 if(t.contains(
"working_directory")) {
135 std::filesystem::path tmp_wd(
static_cast<std::string
>(t.at(
"working_directory").as_string()));
136 if(tmp_wd.is_relative()) {
137 res.working_directory = std::filesystem::current_path() / tmp_wd;
139 res.working_directory = tmp_wd;
// Logging and observables directories are stored exactly as given; no resolution happens here.
142 if(t.contains(
"logging_directory")) {
143 res.logging_directory = std::filesystem::path(
static_cast<std::string
>(t.at(
"logging_directory").as_string()));
145 if(t.contains(
"obs_directory")) { res.obs_directory = std::filesystem::path(
static_cast<std::string
>(t.at(
"obs_directory").as_string())); }
146 if(t.contains(
"verbosity")) { res.verbosity = util::enum_from_toml<Verbosity>(t.at(
"verbosity")); }
147 res.display_obs = t.contains(
"display_obs") ? t.at(
"display_obs").as_boolean() : res.display_obs;
// seed and id are read as TOML integers; their declared types are not visible in this excerpt.
148 res.seed = t.contains(
"seed") ? (t.at(
"seed").as_integer()) : res.seed;
149 res.id = t.contains(
"id") ? (t.at(
"id").as_integer()) : res.id;
Optim optim_from_toml(const toml::value &t)
Definition: OptimOpts.hpp:117
Definition: OptimOpts.hpp:25
Algorithm alg
Definition: OptimOpts.hpp:26
auto info()
Definition: OptimOpts.hpp:88
void serialize(Ar &ar)
Definition: OptimOpts.hpp:63
double bfgs_xxx
Definition: OptimOpts.hpp:30