You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

928 lines
30 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Classes to accumulate arc weights. Useful for weight lookahead.
  19. #ifndef FST_ACCUMULATOR_H_
  20. #define FST_ACCUMULATOR_H_
  21. #include <sys/types.h>
  22. #include <algorithm>
  23. #include <cstddef>
  24. #include <functional>
  25. #include <memory>
  26. #include <utility>
  27. #include <vector>
  28. #include <fst/log.h>
  29. #include <fst/arcfilter.h>
  30. #include <fst/arcsort.h>
  31. #include <fst/dfs-visit.h>
  32. #include <fst/expanded-fst.h>
  33. #include <fst/float-weight.h>
  34. #include <fst/fst.h>
  35. #include <fst/replace.h>
  36. #include <fst/util.h>
  37. #include <fst/weight.h>
  38. #include <unordered_map>
  39. namespace fst {
  40. // This class accumulates arc weights using the semiring Plus().
  41. // Sum(w, aiter, begin, end) has time complexity O(begin - end).
  42. template <class A>
  43. class DefaultAccumulator {
  44. public:
  45. using Arc = A;
  46. using StateId = typename Arc::StateId;
  47. using Weight = typename Arc::Weight;
  48. DefaultAccumulator() = default;
  49. DefaultAccumulator(const DefaultAccumulator &acc, bool safe = false) {}
  50. void Init(const Fst<Arc> &fst, bool copy = false) {}
  51. void SetState(StateId state) {}
  52. Weight Sum(Weight w, Weight v) { return Plus(w, v); }
  53. template <class ArcIter>
  54. Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
  55. Adder<Weight> adder(w); // maintains cumulative sum accurately
  56. aiter->Seek(begin);
  57. for (auto pos = begin; pos < end; aiter->Next(), ++pos)
  58. adder.Add(aiter->Value().weight);
  59. return adder.Sum();
  60. }
  61. constexpr bool Error() const { return false; }
  62. private:
  63. DefaultAccumulator &operator=(const DefaultAccumulator &) = delete;
  64. };
  65. // This class accumulates arc weights using the log semiring Plus() assuming an
  66. // arc weight has a WeightConvert specialization to and from log64 weights.
  67. // Sum(w, aiter, begin, end) has time complexity O(begin - end).
  68. template <class A>
  69. class LogAccumulator {
  70. public:
  71. using Arc = A;
  72. using StateId = typename Arc::StateId;
  73. using Weight = typename Arc::Weight;
  74. LogAccumulator() = default;
  75. LogAccumulator(const LogAccumulator &acc, bool safe = false) {}
  76. void Init(const Fst<Arc> &fst, bool copy = false) {}
  77. void SetState(StateId s) {}
  78. Weight Sum(Weight w, Weight v) { return LogPlus(w, v); }
  79. template <class ArcIter>
  80. Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
  81. auto sum = w;
  82. aiter->Seek(begin);
  83. for (auto pos = begin; pos < end; aiter->Next(), ++pos) {
  84. sum = LogPlus(sum, aiter->Value().weight);
  85. }
  86. return sum;
  87. }
  88. constexpr bool Error() const { return false; }
  89. private:
  90. Weight LogPlus(Weight w, Weight v) {
  91. if (w == Weight::Zero()) {
  92. return v;
  93. }
  94. const auto f1 = to_log_weight_(w).Value();
  95. const auto f2 = to_log_weight_(v).Value();
  96. if (f1 > f2) {
  97. return to_weight_(Log64Weight(f2 - internal::LogPosExp(f1 - f2)));
  98. } else {
  99. return to_weight_(Log64Weight(f1 - internal::LogPosExp(f2 - f1)));
  100. }
  101. }
  102. const WeightConvert<Weight, Log64Weight> to_log_weight_{};
  103. const WeightConvert<Log64Weight, Weight> to_weight_{};
  104. LogAccumulator &operator=(const LogAccumulator &) = delete;
  105. };
  // Interface for shareable data for fast log accumulator copies. Holds pointers
  // to data only, storage is provided by derived classes.
  class FastLogAccumulatorData {
   public:
    // Records the two configuration values; the data pointers start out null
    // and both counts start at zero until a derived class calls Init().
    FastLogAccumulatorData(int arc_limit, int arc_period)
        : arc_limit_(arc_limit),
          arc_period_(arc_period),
          weights_ptr_(nullptr),
          num_weights_(0),
          weight_positions_ptr_(nullptr),
          num_positions_(0) {}

    virtual ~FastLogAccumulatorData() = default;

    // Cumulative weight per state for all states s.t. # of arcs > arc_limit_
    // with arcs in order. The first element per state is Log64Weight::Zero().
    const double *Weights() const { return weights_ptr_; }

    // Number of elements reachable through Weights().
    int NumWeights() const { return num_weights_; }

    // Maps from state to corresponding beginning weight position in weights_.
    // Position -1 means no pre-computed weights for that state.
    const int *WeightPositions() const { return weight_positions_ptr_; }

    // Number of elements reachable through WeightPositions().
    int NumPositions() const { return num_positions_; }

    int ArcLimit() const { return arc_limit_; }

    int ArcPeriod() const { return arc_period_; }

    // Returns true if the data object is mutable and supports SetData().
    virtual bool IsMutable() const = 0;

    // Does not take ownership but may invalidate the contents of weights and
    // weight_positions.
    virtual void SetData(std::vector<double> *weights,
                         std::vector<int> *weight_positions) = 0;

   protected:
    // Points this object at externally stored arrays. Only the pointers and
    // the counts are stored; the caller keeps the storage alive.
    void Init(int num_weights, const double *weights, int num_positions,
              const int *weight_positions) {
      weights_ptr_ = weights;
      num_weights_ = num_weights;
      weight_positions_ptr_ = weight_positions;
      num_positions_ = num_positions;
    }

   private:
    const int arc_limit_;
    const int arc_period_;
    const double *weights_ptr_;
    int num_weights_;
    const int *weight_positions_ptr_;
    int num_positions_;

    FastLogAccumulatorData(const FastLogAccumulatorData &) = delete;
    FastLogAccumulatorData &operator=(const FastLogAccumulatorData &) = delete;
  };
  152. // FastLogAccumulatorData with mutable storage; filled by
  153. // FastLogAccumulator::Init.
  154. class MutableFastLogAccumulatorData : public FastLogAccumulatorData {
  155. public:
  156. MutableFastLogAccumulatorData(int arc_limit, int arc_period)
  157. : FastLogAccumulatorData(arc_limit, arc_period) {}
  158. bool IsMutable() const override { return true; }
  159. void SetData(std::vector<double> *weights,
  160. std::vector<int> *weight_positions) override {
  161. weights_.swap(*weights);
  162. weight_positions_.swap(*weight_positions);
  163. Init(weights_.size(), weights_.data(), weight_positions_.size(),
  164. weight_positions_.data());
  165. }
  166. private:
  167. std::vector<double> weights_;
  168. std::vector<int> weight_positions_;
  169. MutableFastLogAccumulatorData(const MutableFastLogAccumulatorData &) = delete;
  170. MutableFastLogAccumulatorData &operator=(
  171. const MutableFastLogAccumulatorData &) = delete;
  172. };
  173. // This class accumulates arc weights using the log semiring Plus() assuming an
  174. // arc weight has a WeightConvert specialization to and from log64 weights. The
  175. // member function Init(fst) has to be called to setup pre-computed weight
  176. // information.
  177. // Sum(w, aiter, begin, end) has time complexity O(arc_limit_) or O(arc_period_)
  178. // depending on whether the state has more than arc_limit_ arcs
  179. // Space complexity is O(CountStates(fst) + CountArcs(fst) / arc_period_).
  180. template <class A>
  181. class FastLogAccumulator {
  182. public:
  183. using Arc = A;
  184. using StateId = typename Arc::StateId;
  185. using Weight = typename Arc::Weight;
  186. explicit FastLogAccumulator(ssize_t arc_limit = 20, ssize_t arc_period = 10)
  187. : to_log_weight_(),
  188. to_weight_(),
  189. arc_limit_(arc_limit),
  190. arc_period_(arc_period),
  191. data_(std::make_shared<MutableFastLogAccumulatorData>(arc_limit,
  192. arc_period)),
  193. state_weights_(nullptr),
  194. error_(false) {}
  195. explicit FastLogAccumulator(std::shared_ptr<FastLogAccumulatorData> data)
  196. : to_log_weight_(),
  197. to_weight_(),
  198. arc_limit_(data->ArcLimit()),
  199. arc_period_(data->ArcPeriod()),
  200. data_(data),
  201. state_weights_(nullptr),
  202. error_(false) {}
  203. FastLogAccumulator(const FastLogAccumulator &acc, bool safe = false)
  204. : to_log_weight_(),
  205. to_weight_(),
  206. arc_limit_(acc.arc_limit_),
  207. arc_period_(acc.arc_period_),
  208. data_(acc.data_),
  209. state_weights_(nullptr),
  210. error_(acc.error_) {}
  211. void SetState(StateId s) {
  212. const auto *weights = data_->Weights();
  213. const auto *weight_positions = data_->WeightPositions();
  214. state_weights_ = nullptr;
  215. if (s < data_->NumPositions()) {
  216. const auto pos = weight_positions[s];
  217. if (pos >= 0) state_weights_ = &(weights[pos]);
  218. }
  219. }
  220. Weight Sum(Weight w, Weight v) const { return LogPlus(w, v); }
  221. template <class ArcIter>
  222. Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) const {
  223. if (error_) return Weight::NoWeight();
  224. auto sum = w;
  225. // Finds begin and end of pre-stored weights.
  226. ssize_t index_begin = -1;
  227. ssize_t index_end = -1;
  228. ssize_t stored_begin = end;
  229. ssize_t stored_end = end;
  230. if (state_weights_) {
  231. index_begin = begin > 0 ? (begin - 1) / arc_period_ + 1 : 0;
  232. index_end = end / arc_period_;
  233. stored_begin = index_begin * arc_period_;
  234. stored_end = index_end * arc_period_;
  235. }
  236. // Computes sum before pre-stored weights.
  237. if (begin < stored_begin) {
  238. const auto pos_end = std::min(stored_begin, end);
  239. aiter->Seek(begin);
  240. for (auto pos = begin; pos < pos_end; aiter->Next(), ++pos) {
  241. sum = LogPlus(sum, aiter->Value().weight);
  242. }
  243. }
  244. // Computes sum between pre-stored weights.
  245. if (stored_begin < stored_end) {
  246. const auto f1 = state_weights_[index_end];
  247. const auto f2 = state_weights_[index_begin];
  248. if (f1 < f2) sum = LogPlus(sum, LogMinus(f1, f2));
  249. // Commented out for efficiency; adds Zero().
  250. /*
  251. else {
  252. // explicitly computes if cumulative sum lacks precision
  253. aiter->Seek(stored_begin);
  254. for (auto pos = stored_begin; pos < stored_end; aiter->Next(), ++pos)
  255. sum = LogPlus(sum, aiter->Value().weight);
  256. }
  257. */
  258. }
  259. // Computes sum after pre-stored weights.
  260. if (stored_end < end) {
  261. const auto pos_start = std::max(stored_begin, stored_end);
  262. aiter->Seek(pos_start);
  263. for (auto pos = pos_start; pos < end; aiter->Next(), ++pos) {
  264. sum = LogPlus(sum, aiter->Value().weight);
  265. }
  266. }
  267. return sum;
  268. }
  269. template <class FST>
  270. void Init(const FST &fst, bool copy = false) {
  271. if (copy || !data_->IsMutable()) return;
  272. if (data_->NumPositions() != 0 || arc_limit_ < arc_period_) {
  273. FSTERROR() << "FastLogAccumulator: Initialization error";
  274. error_ = true;
  275. return;
  276. }
  277. std::vector<double> weights;
  278. std::vector<int> weight_positions;
  279. weight_positions.reserve(CountStates(fst));
  280. for (StateIterator<FST> siter(fst); !siter.Done(); siter.Next()) {
  281. const auto s = siter.Value();
  282. if (fst.NumArcs(s) >= arc_limit_) {
  283. auto sum = FloatLimits<double>::PosInfinity();
  284. if (weight_positions.size() <= s) weight_positions.resize(s + 1, -1);
  285. weight_positions[s] = weights.size();
  286. weights.push_back(sum);
  287. size_t narcs = 0;
  288. ArcIterator<FST> aiter(fst, s);
  289. aiter.SetFlags(kArcWeightValue | kArcNoCache, kArcFlags);
  290. for (; !aiter.Done(); aiter.Next()) {
  291. const auto &arc = aiter.Value();
  292. sum = LogPlus(sum, arc.weight);
  293. // Stores cumulative weight distribution per arc_period_.
  294. if (++narcs % arc_period_ == 0) weights.push_back(sum);
  295. }
  296. }
  297. }
  298. data_->SetData(&weights, &weight_positions);
  299. }
  300. bool Error() const { return error_; }
  301. std::shared_ptr<FastLogAccumulatorData> GetData() const { return data_; }
  302. private:
  303. static double LogPosExp(double x) {
  304. return x == FloatLimits<double>::PosInfinity() ? 0.0
  305. : log(1.0F + exp(-x));
  306. }
  307. static double LogMinusExp(double x) {
  308. return x == FloatLimits<double>::PosInfinity() ? 0.0
  309. : log(1.0F - exp(-x));
  310. }
  311. Weight LogPlus(Weight w, Weight v) const {
  312. if (w == Weight::Zero()) {
  313. return v;
  314. }
  315. const auto f1 = to_log_weight_(w).Value();
  316. const auto f2 = to_log_weight_(v).Value();
  317. if (f1 > f2) {
  318. return to_weight_(Log64Weight(f2 - LogPosExp(f1 - f2)));
  319. } else {
  320. return to_weight_(Log64Weight(f1 - LogPosExp(f2 - f1)));
  321. }
  322. }
  323. double LogPlus(double f1, Weight v) const {
  324. const auto f2 = to_log_weight_(v).Value();
  325. if (f1 == FloatLimits<double>::PosInfinity()) {
  326. return f2;
  327. } else if (f1 > f2) {
  328. return f2 - LogPosExp(f1 - f2);
  329. } else {
  330. return f1 - LogPosExp(f2 - f1);
  331. }
  332. }
  333. // Assumes f1 < f2.
  334. Weight LogMinus(double f1, double f2) const {
  335. if (f2 == FloatLimits<double>::PosInfinity()) {
  336. return to_weight_(Log64Weight(f1));
  337. } else {
  338. return to_weight_(Log64Weight(f1 - LogMinusExp(f2 - f1)));
  339. }
  340. }
  341. const WeightConvert<Weight, Log64Weight> to_log_weight_{};
  342. const WeightConvert<Log64Weight, Weight> to_weight_{};
  343. const ssize_t arc_limit_; // Minimum number of arcs to pre-compute state.
  344. const ssize_t arc_period_; // Saves cumulative weights per arc_period_.
  345. std::shared_ptr<FastLogAccumulatorData> data_;
  346. const double *state_weights_;
  347. bool error_;
  348. FastLogAccumulator &operator=(const FastLogAccumulator &) = delete;
  349. };
  // Stores shareable data for cache log accumulator copies. All copies share the
  // same cache.
  //
  // Pointers returned by GetWeights() point into the cache and stop being valid
  // once garbage collection erases the corresponding entry.
  template <class Arc>
  class CacheLogAccumulatorData {
   public:
    using StateId = typename Arc::StateId;
    using Weight = typename Arc::Weight;

    // `gc` enables garbage collection; `gc_limit` is the number of cached
    // bytes at which a collection is triggered.
    CacheLogAccumulatorData(bool gc, size_t gc_limit)
        : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {}

    // Copies the configuration only; the new object starts with an empty cache.
    CacheLogAccumulatorData(const CacheLogAccumulatorData<Arc> &data)
        : cache_gc_(data.cache_gc_),
          cache_limit_(data.cache_limit_),
          cache_size_(0) {}

    // Caching is disabled when garbage collection is on with a zero limit.
    bool CacheDisabled() const { return cache_gc_ && cache_limit_ == 0; }

    // Returns the cached weights for state s and marks the entry as recently
    // used, or returns nullptr if the state is not cached.
    std::vector<double> *GetWeights(StateId s) {
      if (auto it = cache_.find(s); it != cache_.end()) {
        it->second.recent = true;
        return it->second.weights.get();
      } else {
        return nullptr;
      }
    }

    // Takes ownership of `weights` and caches it for state s, running a
    // garbage collection first if the byte limit has been reached. If s is
    // already cached the existing entry is kept and `weights` is discarded; a
    // null `weights` is ignored.
    void AddWeights(StateId s, std::unique_ptr<std::vector<double>> weights) {
      if (!weights) return;
      if (cache_gc_ && cache_size_ >= cache_limit_) GC(false);
      // Measured before the move below empties `weights`.
      const size_t bytes = weights->capacity() * sizeof(double);
      const bool inserted =
          cache_.try_emplace(s, std::move(weights), true).second;
      // Only entries that were really stored count towards the cache size;
      // otherwise repeated additions for one state would inflate it.
      if (inserted && cache_gc_) cache_size_ += bytes;
    }

   private:
    // Cached information for a given state.
    struct CacheState {
      std::unique_ptr<std::vector<double>> weights;  // Accumulated weights.
      bool recent;  // Has this state been accessed since last GC?

      CacheState(std::unique_ptr<std::vector<double>> weights, bool recent)
          : weights(std::move(weights)), recent(recent) {}
    };

    // Garbage collect: Deletes from cache states that have not been accessed
    // since the last GC ('free_recent = false') until 'cache_size_' is 2/3 of
    // 'cache_limit_'. If it does not free enough memory, start deleting
    // recently accessed states.
    void GC(bool free_recent) {
      const auto cache_target = (2 * cache_limit_) / 3 + 1;
      auto it = cache_.begin();
      while (it != cache_.end() && cache_size_ > cache_target) {
        auto &cs = it->second;
        if (free_recent || !cs.recent) {
          cache_size_ -= cs.weights->capacity() * sizeof(double);
          it = cache_.erase(it);
        } else {
          // Survives this pass but becomes eligible for the next one.
          cs.recent = false;
          ++it;
        }
      }
      if (!free_recent && cache_size_ > cache_target) GC(true);
    }

    std::unordered_map<StateId, CacheState> cache_;  // Cache.
    bool cache_gc_;       // Enables garbage collection.
    size_t cache_limit_;  // # of bytes allowed before GC.
    size_t cache_size_;   // # of bytes cached.

    CacheLogAccumulatorData &operator=(const CacheLogAccumulatorData &) = delete;
  };
  410. // This class accumulates arc weights using the log semiring Plus() has a
  411. // WeightConvert specialization to and from log64 weights. It is similar to the
  412. // FastLogAccumator. However here, the accumulated weights are pre-computed and
  413. // stored only for the states that are visited. The member function Init(fst)
  414. // has to be called to setup this accumulator. Space complexity is O(gc_limit).
  415. template <class Arc>
  416. class CacheLogAccumulator {
  417. public:
  418. using StateId = typename Arc::StateId;
  419. using Weight = typename Arc::Weight;
  420. explicit CacheLogAccumulator(ssize_t arc_limit = 10, bool gc = false,
  421. size_t gc_limit = 10 * 1024 * 1024)
  422. : arc_limit_(arc_limit),
  423. data_(std::make_shared<CacheLogAccumulatorData<Arc>>(gc, gc_limit)),
  424. s_(kNoStateId),
  425. error_(false) {}
  426. CacheLogAccumulator(const CacheLogAccumulator &acc, bool safe = false)
  427. : arc_limit_(acc.arc_limit_),
  428. fst_(acc.fst_ ? acc.fst_->Copy() : nullptr),
  429. data_(safe ? std::make_shared<CacheLogAccumulatorData<Arc>>(*acc.data_)
  430. : acc.data_),
  431. s_(kNoStateId),
  432. error_(acc.error_) {}
  433. // Argument arc_limit specifies the minimum number of arcs to pre-compute.
  434. void Init(const Fst<Arc> &fst, bool copy = false) {
  435. if (!copy && fst_) {
  436. FSTERROR() << "CacheLogAccumulator: Initialization error";
  437. error_ = true;
  438. return;
  439. }
  440. fst_.reset(fst.Copy());
  441. }
  442. void SetState(StateId s, int depth = 0) {
  443. if (s == s_) return;
  444. s_ = s;
  445. if (data_->CacheDisabled() || error_) {
  446. weights_ = nullptr;
  447. return;
  448. }
  449. if (!fst_) {
  450. FSTERROR() << "CacheLogAccumulator::SetState: Incorrectly initialized";
  451. error_ = true;
  452. weights_ = nullptr;
  453. return;
  454. }
  455. weights_ = data_->GetWeights(s);
  456. if ((weights_ == nullptr) && (fst_->NumArcs(s) >= arc_limit_)) {
  457. auto weights = std::make_unique<std::vector<double>>();
  458. weights->reserve(fst_->NumArcs(s) + 1);
  459. weights->push_back(FloatLimits<double>::PosInfinity());
  460. // `weights` holds a reference to the weight vector, whose ownership is
  461. // transferred to `data_`.
  462. weights_ = weights.get();
  463. data_->AddWeights(s, std::move(weights));
  464. }
  465. }
  466. Weight Sum(Weight w, Weight v) { return LogPlus(w, v); }
  467. template <class ArcIter>
  468. Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
  469. if (weights_ == nullptr) {
  470. auto sum = w;
  471. aiter->Seek(begin);
  472. for (auto pos = begin; pos < end; aiter->Next(), ++pos) {
  473. sum = LogPlus(sum, aiter->Value().weight);
  474. }
  475. return sum;
  476. } else {
  477. Extend(end, aiter);
  478. const auto &f1 = (*weights_)[end];
  479. const auto &f2 = (*weights_)[begin];
  480. if (f1 < f2) {
  481. return LogPlus(w, LogMinus(f1, f2));
  482. } else {
  483. // Commented out for efficiency; adds Zero().
  484. /*
  485. auto sum = w;
  486. // Explicitly computes if cumulative sum lacks precision.
  487. aiter->Seek(begin);
  488. for (auto pos = begin; pos < end; aiter->Next(), ++pos) {
  489. sum = LogPlus(sum, aiter->Value().weight);
  490. }
  491. return sum;
  492. */
  493. return w;
  494. }
  495. }
  496. }
  497. // Returns first position from aiter->Position() whose accumulated
  498. // value is greater or equal to w (w.r.t. Zero() < One()). The
  499. // iterator may be repositioned.
  500. template <class ArcIter>
  501. size_t LowerBound(Weight w, ArcIter *aiter) {
  502. const auto f = to_log_weight_(w).Value();
  503. auto pos = aiter->Position();
  504. if (weights_) {
  505. Extend(fst_->NumArcs(s_), aiter);
  506. return std::lower_bound(weights_->begin() + pos + 1, weights_->end(), f,
  507. std::greater<double>()) -
  508. weights_->begin() - 1;
  509. } else {
  510. size_t n = 0;
  511. auto x = FloatLimits<double>::PosInfinity();
  512. for (aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) {
  513. x = LogPlus(x, aiter->Value().weight);
  514. if (n >= pos && x <= f) break;
  515. }
  516. return n;
  517. }
  518. }
  519. bool Error() const { return error_; }
  520. private:
  521. double LogPosExp(double x) {
  522. return x == FloatLimits<double>::PosInfinity() ? 0.0
  523. : log(1.0F + exp(-x));
  524. }
  525. double LogMinusExp(double x) {
  526. return x == FloatLimits<double>::PosInfinity() ? 0.0
  527. : log(1.0F - exp(-x));
  528. }
  529. Weight LogPlus(Weight w, Weight v) {
  530. if (w == Weight::Zero()) {
  531. return v;
  532. }
  533. const auto f1 = to_log_weight_(w).Value();
  534. const auto f2 = to_log_weight_(v).Value();
  535. if (f1 > f2) {
  536. return to_weight_(Log64Weight(f2 - LogPosExp(f1 - f2)));
  537. } else {
  538. return to_weight_(Log64Weight(f1 - LogPosExp(f2 - f1)));
  539. }
  540. }
  541. double LogPlus(double f1, Weight v) {
  542. const auto f2 = to_log_weight_(v).Value();
  543. if (f1 == FloatLimits<double>::PosInfinity()) {
  544. return f2;
  545. } else if (f1 > f2) {
  546. return f2 - LogPosExp(f1 - f2);
  547. } else {
  548. return f1 - LogPosExp(f2 - f1);
  549. }
  550. }
  551. // Assumes f1 < f2.
  552. Weight LogMinus(double f1, double f2) {
  553. if (f2 == FloatLimits<double>::PosInfinity()) {
  554. return to_weight_(Log64Weight(f1));
  555. } else {
  556. return to_weight_(Log64Weight(f1 - LogMinusExp(f2 - f1)));
  557. }
  558. }
  559. // Extends weights up to index 'end'.
  560. template <class ArcIter>
  561. void Extend(ssize_t end, ArcIter *aiter) {
  562. if (weights_->size() <= end) {
  563. for (aiter->Seek(weights_->size() - 1); weights_->size() <= end;
  564. aiter->Next()) {
  565. weights_->push_back(LogPlus(weights_->back(), aiter->Value().weight));
  566. }
  567. }
  568. }
  569. const WeightConvert<Weight, Log64Weight> to_log_weight_{};
  570. const WeightConvert<Log64Weight, Weight> to_weight_{};
  571. ssize_t arc_limit_; // Minimum # of arcs to cache a state.
  572. std::vector<double> *weights_; // Accumulated weights for cur. state.
  573. // Pointee owned by `data_`.
  574. std::unique_ptr<const Fst<Arc>> fst_; // Input FST.
  575. std::shared_ptr<CacheLogAccumulatorData<Arc>> data_; // Cache data.
  576. StateId s_; // Current state.
  577. bool error_;
  578. };
  579. // Stores shareable data for replace accumulator copies.
  580. template <class Accumulator, class T>
  581. class ReplaceAccumulatorData {
  582. public:
  583. using Arc = typename Accumulator::Arc;
  584. using Label = typename Arc::Label;
  585. using StateId = typename Arc::StateId;
  586. using StateTable = T;
  587. using StateTuple = typename StateTable::StateTuple;
  588. ReplaceAccumulatorData() : state_table_(nullptr) {}
  589. explicit ReplaceAccumulatorData(
  590. std::vector<std::unique_ptr<Accumulator>> &&accumulators)
  591. : state_table_(nullptr), accumulators_(std::move(accumulators)) {}
  592. void Init(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_tuples,
  593. const StateTable *state_table) {
  594. state_table_ = state_table;
  595. accumulators_.resize(fst_tuples.size());
  596. for (Label i = 0; i < accumulators_.size(); ++i) {
  597. if (!accumulators_[i]) {
  598. accumulators_[i] = std::make_unique<Accumulator>();
  599. accumulators_[i]->Init(*(fst_tuples[i].second));
  600. }
  601. fst_array_.emplace_back(fst_tuples[i].second->Copy());
  602. }
  603. }
  604. const StateTuple &GetTuple(StateId s) const { return state_table_->Tuple(s); }
  605. Accumulator *GetAccumulator(size_t i) { return accumulators_[i].get(); }
  606. const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i].get(); }
  607. private:
  608. const StateTable *state_table_;
  609. std::vector<std::unique_ptr<Accumulator>> accumulators_;
  610. std::vector<std::unique_ptr<const Fst<Arc>>> fst_array_;
  611. };
  612. // This class accumulates weights in a ReplaceFst. The 'Init' method takes as
  613. // input the argument used to build the ReplaceFst and the ReplaceFst state
  614. // table. It uses accumulators of type 'Accumulator' in the underlying FSTs.
  615. template <class Accumulator,
  616. class T = DefaultReplaceStateTable<typename Accumulator::Arc>>
  617. class ReplaceAccumulator {
  618. public:
  619. using Arc = typename Accumulator::Arc;
  620. using Label = typename Arc::Label;
  621. using StateId = typename Arc::StateId;
  622. using StateTable = T;
  623. using StateTuple = typename StateTable::StateTuple;
  624. using Weight = typename Arc::Weight;
  625. ReplaceAccumulator()
  626. : init_(false),
  627. data_(std::make_shared<
  628. ReplaceAccumulatorData<Accumulator, StateTable>>()),
  629. error_(false) {}
  630. explicit ReplaceAccumulator(
  631. std::vector<std::unique_ptr<Accumulator>> &&accumulators)
  632. : init_(false),
  633. data_(std::make_shared<ReplaceAccumulatorData<Accumulator, StateTable>>(
  634. std::move(accumulators))),
  635. error_(false) {}
  636. ReplaceAccumulator(const ReplaceAccumulator<Accumulator, StateTable> &acc,
  637. bool safe = false)
  638. : init_(acc.init_), data_(acc.data_), error_(acc.error_) {
  639. if (!init_) {
  640. FSTERROR() << "ReplaceAccumulator: Can't copy unintialized accumulator";
  641. }
  642. if (safe) FSTERROR() << "ReplaceAccumulator: Safe copy not supported";
  643. }
  644. // Does not take ownership of the state table, the state table is owned by
  645. // the ReplaceFst.
  646. void Init(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_tuples,
  647. const StateTable *state_table) {
  648. init_ = true;
  649. data_->Init(fst_tuples, state_table);
  650. }
  651. // Method required by LookAheadMatcher. However, ReplaceAccumulator needs to
  652. // be initialized by calling the Init method above before being passed to
  653. // LookAheadMatcher.
  654. //
  655. // TODO(allauzen): Revisit this. Consider creating a method
  656. // Init(const ReplaceFst<A, T, C>&, bool) and using friendship to get access
  657. // to the innards of ReplaceFst.
  658. void Init(const Fst<Arc> &fst, bool copy = false) {
  659. if (!init_) {
  660. FSTERROR() << "ReplaceAccumulator::Init: Accumulator needs to be"
  661. << " initialized before being passed to LookAheadMatcher";
  662. error_ = true;
  663. }
  664. }
  665. void SetState(StateId s) {
  666. if (!init_) {
  667. FSTERROR() << "ReplaceAccumulator::SetState: Incorrectly initialized";
  668. error_ = true;
  669. return;
  670. }
  671. auto tuple = data_->GetTuple(s);
  672. fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based.
  673. data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state);
  674. if ((tuple.prefix_id != 0) &&
  675. (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) {
  676. offset_ = 1;
  677. offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state);
  678. } else {
  679. offset_ = 0;
  680. offset_weight_ = Weight::Zero();
  681. }
  682. aiter_ = std::make_unique<ArcIterator<Fst<Arc>>>(*data_->GetFst(fst_id_),
  683. tuple.fst_state);
  684. }
  685. Weight Sum(Weight w, Weight v) {
  686. if (error_) return Weight::NoWeight();
  687. return data_->GetAccumulator(fst_id_)->Sum(w, v);
  688. }
  689. template <class ArcIter>
  690. Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
  691. if (error_) return Weight::NoWeight();
  692. auto sum = begin == end ? Weight::Zero()
  693. : data_->GetAccumulator(fst_id_)->Sum(
  694. w, aiter_.get(), begin ? begin - offset_ : 0,
  695. end - offset_);
  696. if (begin == 0 && end != 0 && offset_ > 0) sum = Sum(offset_weight_, sum);
  697. return sum;
  698. }
  699. bool Error() const { return error_; }
  700. private:
  701. bool init_;
  702. std::shared_ptr<ReplaceAccumulatorData<Accumulator, StateTable>> data_;
  703. Label fst_id_;
  704. size_t offset_;
  705. Weight offset_weight_;
  706. std::unique_ptr<ArcIterator<Fst<Arc>>> aiter_;
  707. bool error_;
  708. };
  709. // SafeReplaceAccumulator accumulates weights in a ReplaceFst and copies of it
  710. // are always thread-safe copies.
  711. template <class Accumulator, class T>
  712. class SafeReplaceAccumulator {
  713. public:
  714. using Arc = typename Accumulator::Arc;
  715. using StateId = typename Arc::StateId;
  716. using Label = typename Arc::Label;
  717. using Weight = typename Arc::Weight;
  718. using StateTable = T;
  719. using StateTuple = typename StateTable::StateTuple;
  720. SafeReplaceAccumulator() = default;
  721. SafeReplaceAccumulator(const SafeReplaceAccumulator &copy, bool safe)
  722. : SafeReplaceAccumulator(copy) {}
  723. explicit SafeReplaceAccumulator(
  724. const std::vector<Accumulator> &accumulators) {
  725. for (const auto &accumulator : accumulators) {
  726. accumulators_.emplace_back(accumulator, true);
  727. }
  728. }
  729. void Init(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_tuples,
  730. const StateTable *state_table) {
  731. state_table_ = state_table;
  732. for (Label i = 0; i < fst_tuples.size(); ++i) {
  733. if (i == accumulators_.size()) {
  734. accumulators_.resize(accumulators_.size() + 1);
  735. accumulators_[i].Init(*(fst_tuples[i].second));
  736. }
  737. fst_array_.emplace_back(fst_tuples[i].second->Copy(true));
  738. }
  739. init_ = true;
  740. }
  741. void Init(const Fst<Arc> &fst, bool copy = false) {
  742. if (!init_) {
  743. FSTERROR() << "SafeReplaceAccumulator::Init: Accumulator needs to be"
  744. << " initialized before being passed to LookAheadMatcher";
  745. error_ = true;
  746. }
  747. }
  748. void SetState(StateId s) {
  749. auto tuple = state_table_->Tuple(s);
  750. fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based
  751. GetAccumulator(fst_id_)->SetState(tuple.fst_state);
  752. offset_ = 0;
  753. offset_weight_ = Weight::Zero();
  754. const auto final_weight = GetFst(fst_id_)->Final(tuple.fst_state);
  755. if ((tuple.prefix_id != 0) && (final_weight != Weight::Zero())) {
  756. offset_ = 1;
  757. offset_weight_ = final_weight;
  758. }
  759. aiter_.Set(*GetFst(fst_id_), tuple.fst_state);
  760. }
  761. Weight Sum(Weight w, Weight v) {
  762. if (error_) return Weight::NoWeight();
  763. return GetAccumulator(fst_id_)->Sum(w, v);
  764. }
  765. template <class ArcIter>
  766. Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
  767. if (error_) return Weight::NoWeight();
  768. if (begin == end) return Weight::Zero();
  769. auto sum = GetAccumulator(fst_id_)->Sum(
  770. w, aiter_.get(), begin ? begin - offset_ : 0, end - offset_);
  771. if (begin == 0 && end != 0 && offset_ > 0) {
  772. sum = Sum(offset_weight_, sum);
  773. }
  774. return sum;
  775. }
  776. bool Error() const { return error_; }
  777. private:
  778. class ArcIteratorPtr {
  779. public:
  780. ArcIteratorPtr() = default;
  781. ArcIteratorPtr(const ArcIteratorPtr &copy) {}
  782. void Set(const Fst<Arc> &fst, StateId state_id) {
  783. ptr_ = std::make_unique<ArcIterator<Fst<Arc>>>(fst, state_id);
  784. }
  785. ArcIterator<Fst<Arc>> *get() { return ptr_.get(); }
  786. private:
  787. std::unique_ptr<ArcIterator<Fst<Arc>>> ptr_;
  788. };
  789. Accumulator *GetAccumulator(size_t i) { return &accumulators_[i]; }
  790. const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i].get(); }
  791. const StateTable *state_table_;
  792. std::vector<Accumulator> accumulators_;
  793. std::vector<std::shared_ptr<Fst<Arc>>> fst_array_;
  794. ArcIteratorPtr aiter_;
  795. bool init_ = false;
  796. bool error_ = false;
  797. Label fst_id_;
  798. size_t offset_;
  799. Weight offset_weight_;
  800. };
  801. } // namespace fst
  802. #endif // FST_ACCUMULATOR_H_