row.h 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. /*-----------------------------------------------------------------------------
  2. | Copyright (c) 2013-2017, Nucleic Development Team.
  3. |
  4. | Distributed under the terms of the Modified BSD License.
  5. |
  6. | The full license is in the file LICENSE, distributed with this software.
  7. |----------------------------------------------------------------------------*/
  8. #pragma once
  9. #include "maptype.h"
  10. #include "symbol.h"
  11. #include "util.h"
  12. namespace kiwi
  13. {
  14. namespace impl
  15. {
  16. class Row
  17. {
  18. public:
  19. using CellMap = MapType<Symbol, double>;
  20. Row() : Row(0.0) {}
  21. Row(double constant) : m_constant(constant) {}
  22. Row(const Row &other) = default;
  23. ~Row() = default;
  24. const CellMap &cells() const
  25. {
  26. return m_cells;
  27. }
  28. double constant() const
  29. {
  30. return m_constant;
  31. }
  32. /* Add a constant value to the row constant.
  33. The new value of the constant is returned.
  34. */
  35. double add(double value)
  36. {
  37. return m_constant += value;
  38. }
  39. /* Insert a symbol into the row with a given coefficient.
  40. If the symbol already exists in the row, the coefficient will be
  41. added to the existing coefficient. If the resulting coefficient
  42. is zero, the symbol will be removed from the row.
  43. */
  44. void insert(const Symbol &symbol, double coefficient = 1.0)
  45. {
  46. if (nearZero(m_cells[symbol] += coefficient))
  47. m_cells.erase(symbol);
  48. }
  49. /* Insert a row into this row with a given coefficient.
  50. The constant and the cells of the other row will be multiplied by
  51. the coefficient and added to this row. Any cell with a resulting
  52. coefficient of zero will be removed from the row.
  53. */
  54. void insert(const Row &other, double coefficient = 1.0)
  55. {
  56. m_constant += other.m_constant * coefficient;
  57. for (const auto & cellPair : other.m_cells)
  58. {
  59. double coeff = cellPair.second * coefficient;
  60. if (nearZero(m_cells[cellPair.first] += coeff))
  61. m_cells.erase(cellPair.first);
  62. }
  63. }
  64. /* Remove the given symbol from the row.
  65. */
  66. void remove(const Symbol &symbol)
  67. {
  68. auto it = m_cells.find(symbol);
  69. if (it != m_cells.end())
  70. m_cells.erase(it);
  71. }
  72. /* Reverse the sign of the constant and all cells in the row.
  73. */
  74. void reverseSign()
  75. {
  76. m_constant = -m_constant;
  77. for (auto &cellPair : m_cells)
  78. cellPair.second = -cellPair.second;
  79. }
  80. /* Solve the row for the given symbol.
  81. This method assumes the row is of the form a * x + b * y + c = 0
  82. and (assuming solve for x) will modify the row to represent the
  83. right hand side of x = -b/a * y - c / a. The target symbol will
  84. be removed from the row, and the constant and other cells will
  85. be multiplied by the negative inverse of the target coefficient.
  86. The given symbol *must* exist in the row.
  87. */
  88. void solveFor(const Symbol &symbol)
  89. {
  90. double coeff = -1.0 / m_cells[symbol];
  91. m_cells.erase(symbol);
  92. m_constant *= coeff;
  93. for (auto &cellPair : m_cells)
  94. cellPair.second *= coeff;
  95. }
  96. /* Solve the row for the given symbols.
  97. This method assumes the row is of the form x = b * y + c and will
  98. solve the row such that y = x / b - c / b. The rhs symbol will be
  99. removed from the row, the lhs added, and the result divided by the
  100. negative inverse of the rhs coefficient.
  101. The lhs symbol *must not* exist in the row, and the rhs symbol
  102. *must* exist in the row.
  103. */
  104. void solveFor(const Symbol &lhs, const Symbol &rhs)
  105. {
  106. insert(lhs, -1.0);
  107. solveFor(rhs);
  108. }
  109. /* Get the coefficient for the given symbol.
  110. If the symbol does not exist in the row, zero will be returned.
  111. */
  112. double coefficientFor(const Symbol &symbol) const
  113. {
  114. CellMap::const_iterator it = m_cells.find(symbol);
  115. if (it == m_cells.end())
  116. return 0.0;
  117. return it->second;
  118. }
  119. /* Substitute a symbol with the data from another row.
  120. Given a row of the form a * x + b and a substitution of the
  121. form x = 3 * y + c the row will be updated to reflect the
  122. expression 3 * a * y + a * c + b.
  123. If the symbol does not exist in the row, this is a no-op.
  124. */
  125. void substitute(const Symbol &symbol, const Row &row)
  126. {
  127. auto it = m_cells.find(symbol);
  128. if (it != m_cells.end())
  129. {
  130. double coefficient = it->second;
  131. m_cells.erase(it);
  132. insert(row, coefficient);
  133. }
  134. }
  135. private:
  136. CellMap m_cells;
  137. double m_constant;
  138. };
  139. } // namespace impl
  140. } // namespace kiwi