You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

484 lines
16 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Functions and classes to relabel an FST (either on input or output).
  19. #ifndef FST_RELABEL_H_
  20. #define FST_RELABEL_H_
  21. #include <cstddef>
  22. #include <cstdint>
  23. #include <memory>
  24. #include <string>
  25. #include <utility>
  26. #include <vector>
  27. #include <fst/log.h>
  28. #include <fst/arc.h>
  29. #include <fst/cache.h>
  30. #include <fst/float-weight.h>
  31. #include <fst/fst.h>
  32. #include <fst/impl-to-fst.h>
  33. #include <fst/mutable-fst.h>
  34. #include <fst/properties.h>
  35. #include <fst/symbol-table.h>
  36. #include <fst/util.h>
  37. #include <unordered_map>
  38. namespace fst {
  39. // Relabels either the input labels or output labels. The old to
  40. // new labels are specified using a vector of std::pair<Label, Label>.
  41. // Any label associations not specified are assumed to be identity
  42. // mapping. The destination labels must be valid labels (e.g., not kNoLabel).
  43. template <class Arc>
  44. void Relabel(
  45. MutableFst<Arc> *fst,
  46. const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
  47. &ipairs,
  48. const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
  49. &opairs) {
  50. using Label = typename Arc::Label;
  51. const auto props = fst->Properties(kFstProperties, false);
  52. // Constructs label-to-label maps.
  53. const std::unordered_map<Label, Label> input_map(
  54. ipairs.begin(), ipairs.end());
  55. const std::unordered_map<Label, Label> output_map(
  56. opairs.begin(), opairs.end());
  57. for (StateIterator<MutableFst<Arc>> siter(*fst); !siter.Done();
  58. siter.Next()) {
  59. for (MutableArcIterator<MutableFst<Arc>> aiter(fst, siter.Value());
  60. !aiter.Done(); aiter.Next()) {
  61. auto arc = aiter.Value();
  62. // dense_hash_map does not support find on the empty_key_val.
  63. // These labels should never be in an FST anyway.
  64. DCHECK_NE(arc.ilabel, kNoLabel);
  65. DCHECK_NE(arc.olabel, kNoLabel);
  66. // Relabels input.
  67. if (auto it = input_map.find(arc.ilabel); it != input_map.end()) {
  68. if (it->second == kNoLabel) {
  69. FSTERROR() << "Input symbol ID " << arc.ilabel
  70. << " missing from target vocabulary";
  71. fst->SetProperties(kError, kError);
  72. return;
  73. }
  74. arc.ilabel = it->second;
  75. }
  76. // Relabels output.
  77. if (auto it = output_map.find(arc.olabel); it != output_map.end()) {
  78. if (it->second == kNoLabel) {
  79. FSTERROR() << "Output symbol id " << arc.olabel
  80. << " missing from target vocabulary";
  81. fst->SetProperties(kError, kError);
  82. return;
  83. }
  84. arc.olabel = it->second;
  85. }
  86. aiter.SetValue(arc);
  87. }
  88. }
  89. fst->SetProperties(RelabelProperties(props), kFstProperties);
  90. }
  91. // Relabels either the input labels or output labels. The old to
  92. // new labels are specified using pairs of old and new symbol tables.
  93. // The tables must contain (at least) all labels on the appropriate side of the
  94. // FST. If the 'unknown_i(o)symbol' is non-empty, it is used to label any
  95. // missing symbol in new_i(o)symbols table.
  96. template <class Arc>
  97. void Relabel(MutableFst<Arc> *fst, const SymbolTable *old_isymbols,
  98. const SymbolTable *new_isymbols,
  99. const std::string &unknown_isymbol, bool attach_new_isymbols,
  100. const SymbolTable *old_osymbols, const SymbolTable *new_osymbols,
  101. const std::string &unknown_osymbol, bool attach_new_osymbols) {
  102. using Label = typename Arc::Label;
  103. // Constructs vectors of input-side label pairs.
  104. std::vector<std::pair<Label, Label>> ipairs;
  105. if (old_isymbols && new_isymbols) {
  106. size_t num_missing_syms = 0;
  107. Label unknown_ilabel = kNoLabel;
  108. if (!unknown_isymbol.empty()) {
  109. unknown_ilabel = new_isymbols->Find(unknown_isymbol);
  110. if (unknown_ilabel == kNoLabel) {
  111. VLOG(1) << "Input symbol '" << unknown_isymbol
  112. << "' missing from target symbol table";
  113. ++num_missing_syms;
  114. }
  115. }
  116. for (const auto &sitem : *old_isymbols) {
  117. const auto old_index = sitem.Label();
  118. const auto symbol = sitem.Symbol();
  119. auto new_index = new_isymbols->Find(symbol);
  120. if (new_index == kNoLabel) {
  121. if (unknown_ilabel != kNoLabel) {
  122. new_index = unknown_ilabel;
  123. } else {
  124. VLOG(1) << "Input symbol ID " << old_index << " symbol '" << symbol
  125. << "' missing from target symbol table";
  126. ++num_missing_syms;
  127. }
  128. }
  129. ipairs.emplace_back(old_index, new_index);
  130. }
  131. if (num_missing_syms > 0) {
  132. LOG(WARNING) << "Target symbol table missing: " << num_missing_syms
  133. << " input symbols";
  134. }
  135. if (attach_new_isymbols) fst->SetInputSymbols(new_isymbols);
  136. }
  137. // Constructs vectors of output-side label pairs.
  138. std::vector<std::pair<Label, Label>> opairs;
  139. if (old_osymbols && new_osymbols) {
  140. size_t num_missing_syms = 0;
  141. Label unknown_olabel = kNoLabel;
  142. if (!unknown_osymbol.empty()) {
  143. unknown_olabel = new_osymbols->Find(unknown_osymbol);
  144. if (unknown_olabel == kNoLabel) {
  145. VLOG(1) << "Output symbol '" << unknown_osymbol
  146. << "' missing from target symbol table";
  147. ++num_missing_syms;
  148. }
  149. }
  150. for (const auto &sitem : *old_osymbols) {
  151. const auto old_index = sitem.Label();
  152. const auto symbol = sitem.Symbol();
  153. auto new_index = new_osymbols->Find(symbol);
  154. if (new_index == kNoLabel) {
  155. if (unknown_olabel != kNoLabel) {
  156. new_index = unknown_olabel;
  157. } else {
  158. VLOG(1) << "Output symbol ID " << old_index << " symbol '" << symbol
  159. << "' missing from target symbol table";
  160. ++num_missing_syms;
  161. }
  162. }
  163. opairs.emplace_back(old_index, new_index);
  164. }
  165. if (num_missing_syms > 0) {
  166. LOG(WARNING) << "Target symbol table missing: " << num_missing_syms
  167. << " output symbols";
  168. }
  169. if (attach_new_osymbols) fst->SetOutputSymbols(new_osymbols);
  170. }
  171. // Calls relabel using vector of relabel pairs.
  172. Relabel(fst, ipairs, opairs);
  173. }
  174. // Same as previous but no special allowance for unknown symbols. Kept
  175. // for backward compat.
  176. template <class Arc>
  177. void Relabel(MutableFst<Arc> *fst, const SymbolTable *old_isymbols,
  178. const SymbolTable *new_isymbols, bool attach_new_isymbols,
  179. const SymbolTable *old_osymbols, const SymbolTable *new_osymbols,
  180. bool attach_new_osymbols) {
  181. Relabel(fst, old_isymbols, new_isymbols, "" /* no unknown isymbol */,
  182. attach_new_isymbols, old_osymbols, new_osymbols,
  183. "" /* no unknown osymbol */, attach_new_osymbols);
  184. }
  185. // Relabels either the input labels or output labels. The old to
  186. // new labels are specified using symbol tables. Any label associations not
  187. // specified are assumed to be identity mapping.
  188. template <class Arc>
  189. void Relabel(MutableFst<Arc> *fst, const SymbolTable *new_isymbols,
  190. const SymbolTable *new_osymbols) {
  191. Relabel(fst, fst->InputSymbols(), new_isymbols, true, fst->OutputSymbols(),
  192. new_osymbols, true);
  193. }
  194. using RelabelFstOptions = CacheOptions;
  195. template <class Arc>
  196. class RelabelFst;
  197. namespace internal {
  198. // Relabels an FST from one symbol set to another. Relabeling can either be on
  199. // input or output space. RelabelFst implements a delayed version of the
  200. // relabel. Arcs are relabeled on the fly and not cached; i.e., each request is
  201. // recomputed.
  202. template <class Arc>
  203. class RelabelFstImpl : public CacheImpl<Arc> {
  204. public:
  205. using Label = typename Arc::Label;
  206. using StateId = typename Arc::StateId;
  207. using Weight = typename Arc::Weight;
  208. using Store = DefaultCacheStore<Arc>;
  209. using State = typename Store::State;
  210. using FstImpl<Arc>::SetType;
  211. using FstImpl<Arc>::SetProperties;
  212. using FstImpl<Arc>::WriteHeader;
  213. using FstImpl<Arc>::SetInputSymbols;
  214. using FstImpl<Arc>::SetOutputSymbols;
  215. using CacheImpl<Arc>::PushArc;
  216. using CacheImpl<Arc>::HasArcs;
  217. using CacheImpl<Arc>::HasFinal;
  218. using CacheImpl<Arc>::HasStart;
  219. using CacheImpl<Arc>::SetArcs;
  220. using CacheImpl<Arc>::SetFinal;
  221. using CacheImpl<Arc>::SetStart;
  222. friend class StateIterator<RelabelFst<Arc>>;
  223. RelabelFstImpl(const Fst<Arc> &fst,
  224. const std::vector<std::pair<Label, Label>> &ipairs,
  225. const std::vector<std::pair<Label, Label>> &opairs,
  226. const RelabelFstOptions &opts)
  227. : CacheImpl<Arc>(opts),
  228. fst_(fst.Copy()),
  229. input_map_(ipairs.begin(), ipairs.end()),
  230. output_map_(opairs.begin(), opairs.end()),
  231. relabel_input_(!ipairs.empty()),
  232. relabel_output_(!opairs.empty()) {
  233. SetProperties(RelabelProperties(fst.Properties(kCopyProperties, false)));
  234. SetType("relabel");
  235. }
  236. RelabelFstImpl(const Fst<Arc> &fst, const SymbolTable *old_isymbols,
  237. const SymbolTable *new_isymbols,
  238. const SymbolTable *old_osymbols,
  239. const SymbolTable *new_osymbols, const RelabelFstOptions &opts)
  240. : CacheImpl<Arc>(opts),
  241. fst_(fst.Copy()),
  242. relabel_input_(false),
  243. relabel_output_(false) {
  244. SetType("relabel");
  245. SetProperties(RelabelProperties(fst.Properties(kCopyProperties, false)));
  246. SetInputSymbols(old_isymbols);
  247. SetOutputSymbols(old_osymbols);
  248. if (old_isymbols && new_isymbols &&
  249. old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) {
  250. for (const auto &sitem : *old_isymbols) {
  251. input_map_[sitem.Label()] = new_isymbols->Find(sitem.Symbol());
  252. }
  253. SetInputSymbols(new_isymbols);
  254. relabel_input_ = true;
  255. }
  256. if (old_osymbols && new_osymbols &&
  257. old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) {
  258. for (const auto &sitem : *old_osymbols) {
  259. output_map_[sitem.Label()] = new_osymbols->Find(sitem.Symbol());
  260. }
  261. SetOutputSymbols(new_osymbols);
  262. relabel_output_ = true;
  263. }
  264. }
  265. RelabelFstImpl(const RelabelFstImpl<Arc> &impl)
  266. : CacheImpl<Arc>(impl),
  267. fst_(impl.fst_->Copy(true)),
  268. input_map_(impl.input_map_),
  269. output_map_(impl.output_map_),
  270. relabel_input_(impl.relabel_input_),
  271. relabel_output_(impl.relabel_output_) {
  272. SetType("relabel");
  273. SetProperties(impl.Properties(), kCopyProperties);
  274. SetInputSymbols(impl.InputSymbols());
  275. SetOutputSymbols(impl.OutputSymbols());
  276. }
  277. StateId Start() {
  278. if (!HasStart()) SetStart(fst_->Start());
  279. return CacheImpl<Arc>::Start();
  280. }
  281. Weight Final(StateId s) {
  282. if (!HasFinal(s)) SetFinal(s, fst_->Final(s));
  283. return CacheImpl<Arc>::Final(s);
  284. }
  285. size_t NumArcs(StateId s) {
  286. if (!HasArcs(s)) Expand(s);
  287. return CacheImpl<Arc>::NumArcs(s);
  288. }
  289. size_t NumInputEpsilons(StateId s) {
  290. if (!HasArcs(s)) Expand(s);
  291. return CacheImpl<Arc>::NumInputEpsilons(s);
  292. }
  293. size_t NumOutputEpsilons(StateId s) {
  294. if (!HasArcs(s)) Expand(s);
  295. return CacheImpl<Arc>::NumOutputEpsilons(s);
  296. }
  297. uint64_t Properties() const override { return Properties(kFstProperties); }
  298. // Sets error if found, and returns other FST impl properties.
  299. uint64_t Properties(uint64_t mask) const override {
  300. if ((mask & kError) && fst_->Properties(kError, false)) {
  301. SetProperties(kError, kError);
  302. }
  303. return FstImpl<Arc>::Properties(mask);
  304. }
  305. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
  306. if (!HasArcs(s)) Expand(s);
  307. CacheImpl<Arc>::InitArcIterator(s, data);
  308. }
  309. void Expand(StateId s) {
  310. for (ArcIterator<Fst<Arc>> aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
  311. auto arc = aiter.Value();
  312. if (relabel_input_) {
  313. if (auto it = input_map_.find(arc.ilabel); it != input_map_.end()) {
  314. arc.ilabel = it->second;
  315. }
  316. }
  317. if (relabel_output_) {
  318. if (auto it = output_map_.find(arc.olabel); it != output_map_.end()) {
  319. arc.olabel = it->second;
  320. }
  321. }
  322. PushArc(s, std::move(arc));
  323. }
  324. SetArcs(s);
  325. }
  326. private:
  327. std::unique_ptr<const Fst<Arc>> fst_;
  328. std::unordered_map<Label, Label> input_map_;
  329. std::unordered_map<Label, Label> output_map_;
  330. bool relabel_input_;
  331. bool relabel_output_;
  332. };
  333. } // namespace internal
  334. // This class attaches interface to implementation and handles
  335. // reference counting, delegating most methods to ImplToFst.
  336. template <class A>
  337. class RelabelFst : public ImplToFst<internal::RelabelFstImpl<A>> {
  338. public:
  339. using Arc = A;
  340. using Label = typename Arc::Label;
  341. using StateId = typename Arc::StateId;
  342. using Weight = typename Arc::Weight;
  343. using Store = DefaultCacheStore<Arc>;
  344. using State = typename Store::State;
  345. using Impl = internal::RelabelFstImpl<Arc>;
  346. friend class ArcIterator<RelabelFst<A>>;
  347. friend class StateIterator<RelabelFst<A>>;
  348. RelabelFst(const Fst<Arc> &fst,
  349. const std::vector<std::pair<Label, Label>> &ipairs,
  350. const std::vector<std::pair<Label, Label>> &opairs,
  351. const RelabelFstOptions &opts = RelabelFstOptions())
  352. : ImplToFst<Impl>(std::make_shared<Impl>(fst, ipairs, opairs, opts)) {}
  353. RelabelFst(const Fst<Arc> &fst, const SymbolTable *new_isymbols,
  354. const SymbolTable *new_osymbols,
  355. const RelabelFstOptions &opts = RelabelFstOptions())
  356. : ImplToFst<Impl>(
  357. std::make_shared<Impl>(fst, fst.InputSymbols(), new_isymbols,
  358. fst.OutputSymbols(), new_osymbols, opts)) {}
  359. RelabelFst(const Fst<Arc> &fst, const SymbolTable *old_isymbols,
  360. const SymbolTable *new_isymbols, const SymbolTable *old_osymbols,
  361. const SymbolTable *new_osymbols,
  362. const RelabelFstOptions &opts = RelabelFstOptions())
  363. : ImplToFst<Impl>(std::make_shared<Impl>(fst, old_isymbols, new_isymbols,
  364. old_osymbols, new_osymbols,
  365. opts)) {}
  366. // See Fst<>::Copy() for doc.
  367. RelabelFst(const RelabelFst &fst, bool safe = false)
  368. : ImplToFst<Impl>(fst, safe) {}
  369. // Gets a copy of this RelabelFst. See Fst<>::Copy() for further doc.
  370. RelabelFst *Copy(bool safe = false) const override {
  371. return new RelabelFst(*this, safe);
  372. }
  373. void InitStateIterator(StateIteratorData<Arc> *data) const override;
  374. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
  375. return GetMutableImpl()->InitArcIterator(s, data);
  376. }
  377. private:
  378. using ImplToFst<Impl>::GetImpl;
  379. using ImplToFst<Impl>::GetMutableImpl;
  380. RelabelFst &operator=(const RelabelFst &) = delete;
  381. };
  382. // Specialization for RelabelFst.
  383. template <class Arc>
  384. class StateIterator<RelabelFst<Arc>> : public StateIteratorBase<Arc> {
  385. public:
  386. using StateId = typename Arc::StateId;
  387. explicit StateIterator(const RelabelFst<Arc> &fst)
  388. : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {}
  389. bool Done() const final { return siter_.Done(); }
  390. StateId Value() const final { return s_; }
  391. void Next() final {
  392. if (!siter_.Done()) {
  393. ++s_;
  394. siter_.Next();
  395. }
  396. }
  397. void Reset() final {
  398. s_ = 0;
  399. siter_.Reset();
  400. }
  401. private:
  402. const internal::RelabelFstImpl<Arc> *impl_;
  403. StateIterator<Fst<Arc>> siter_;
  404. StateId s_;
  405. StateIterator(const StateIterator &) = delete;
  406. StateIterator &operator=(const StateIterator &) = delete;
  407. };
  408. // Specialization for RelabelFst.
  409. template <class Arc>
  410. class ArcIterator<RelabelFst<Arc>> : public CacheArcIterator<RelabelFst<Arc>> {
  411. public:
  412. using StateId = typename Arc::StateId;
  413. ArcIterator(const RelabelFst<Arc> &fst, StateId s)
  414. : CacheArcIterator<RelabelFst<Arc>>(fst.GetMutableImpl(), s) {
  415. if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
  416. }
  417. };
  418. template <class Arc>
  419. inline void RelabelFst<Arc>::InitStateIterator(
  420. StateIteratorData<Arc> *data) const {
  421. data->base = std::make_unique<StateIterator<RelabelFst<Arc>>>(*this);
  422. }
  423. // Useful alias when using StdArc.
  424. using StdRelabelFst = RelabelFst<StdArc>;
  425. } // namespace fst
  426. #endif // FST_RELABEL_H_