You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

105 lines
2.8 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Function objects to restrict which arcs are traversed in an FST.
  19. #ifndef FST_ARCFILTER_H_
  20. #define FST_ARCFILTER_H_
  21. #include <fst/fst.h>
  22. #include <fst/util.h>
  23. namespace fst {
  24. // True for all arcs.
  25. template <class Arc>
  26. class AnyArcFilter {
  27. public:
  28. bool operator()(const Arc &arc) const { return true; }
  29. };
  30. // True for (input/output) epsilon arcs.
  31. template <class Arc>
  32. class EpsilonArcFilter {
  33. public:
  34. bool operator()(const Arc &arc) const {
  35. return arc.ilabel == 0 && arc.olabel == 0;
  36. }
  37. };
  38. // True for input epsilon arcs.
  39. template <class Arc>
  40. class InputEpsilonArcFilter {
  41. public:
  42. bool operator()(const Arc &arc) const { return arc.ilabel == 0; }
  43. };
  44. // True for output epsilon arcs.
  45. template <class Arc>
  46. class OutputEpsilonArcFilter {
  47. public:
  48. bool operator()(const Arc &arc) const { return arc.olabel == 0; }
  49. };
  50. // True if specified label matches (doesn't match) when keep_match is
  51. // true (false).
  52. template <class Arc>
  53. class LabelArcFilter {
  54. public:
  55. using Label = typename Arc::Label;
  56. explicit LabelArcFilter(Label label, bool match_input = true,
  57. bool keep_match = true)
  58. : label_(label), match_input_(match_input), keep_match_(keep_match) {}
  59. bool operator()(const Arc &arc) const {
  60. const bool match = (match_input_ ? arc.ilabel : arc.olabel) == label_;
  61. return keep_match_ ? match : !match;
  62. }
  63. private:
  64. const Label label_;
  65. const bool match_input_;
  66. const bool keep_match_;
  67. };
  68. // True if specified labels match (don't match) when keep_match is true (false).
  69. template <class Arc>
  70. class MultiLabelArcFilter {
  71. public:
  72. using Label = typename Arc::Label;
  73. explicit MultiLabelArcFilter(bool match_input = true, bool keep_match = true)
  74. : match_input_(match_input), keep_match_(keep_match) {}
  75. bool operator()(const Arc &arc) const {
  76. const Label label = match_input_ ? arc.ilabel : arc.olabel;
  77. const bool match = labels_.Find(label) != labels_.End();
  78. return keep_match_ ? match : !match;
  79. }
  80. void AddLabel(Label label) { labels_.Insert(label); }
  81. private:
  82. CompactSet<Label, kNoLabel> labels_;
  83. const bool match_input_;
  84. const bool keep_match_;
  85. };
  86. } // namespace fst
  87. #endif // FST_ARCFILTER_H_