util.h 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. /*-----------------------------------------------------------------------------
  2. | Copyright (c) 2013-2017, Nucleic Development Team.
  3. |
  4. | Distributed under the terms of the Modified BSD License.
  5. |
  6. | The full license is in the file COPYING.txt, distributed with this software.
  7. |----------------------------------------------------------------------------*/
  8. #pragma once
  9. #include <map>
  10. #include <string>
  11. #include <Python.h>
  12. #include <kiwi/kiwi.h>
  13. #include "pythonhelpers.h"
  14. #include "types.h"
  15. inline bool
  16. convert_to_double( PyObject* obj, double& out )
  17. {
  18. if( PyFloat_Check( obj ) )
  19. {
  20. out = PyFloat_AS_DOUBLE( obj );
  21. return true;
  22. }
  23. #if PY_MAJOR_VERSION < 3
  24. if( PyInt_Check( obj ) )
  25. {
  26. out = double( PyInt_AsLong( obj ) );
  27. return true;
  28. }
  29. #endif
  30. if( PyLong_Check( obj ) )
  31. {
  32. out = PyLong_AsDouble( obj );
  33. if( out == -1.0 && PyErr_Occurred() )
  34. return false;
  35. return true;
  36. }
  37. PythonHelpers::py_expected_type_fail( obj, "float, int, or long" );
  38. return false;
  39. }
  40. inline bool
  41. convert_pystr_to_str( PyObject* value, std::string& out )
  42. {
  43. #if PY_MAJOR_VERSION >= 3
  44. out = PyUnicode_AsUTF8( value );
  45. #else
  46. if( PyUnicode_Check( value ) )
  47. {
  48. PythonHelpers::PyObjectPtr py_str( PyUnicode_AsUTF8String( value ) );
  49. if( !py_str )
  50. return false; // LCOV_EXCL_LINE
  51. out = PyString_AS_STRING( py_str.get() );
  52. }
  53. else
  54. out = PyString_AS_STRING( value );
  55. #endif
  56. return true;
  57. }
  58. inline bool
  59. convert_to_strength( PyObject* value, double& out )
  60. {
  61. #if PY_MAJOR_VERSION >= 3
  62. if( PyUnicode_Check( value ) )
  63. {
  64. #else
  65. if( PyString_Check( value ) | PyUnicode_Check( value ))
  66. {
  67. #endif
  68. std::string str;
  69. if( !convert_pystr_to_str( value, str ) )
  70. return false;
  71. if( str == "required" )
  72. out = kiwi::strength::required;
  73. else if( str == "strong" )
  74. out = kiwi::strength::strong;
  75. else if( str == "medium" )
  76. out = kiwi::strength::medium;
  77. else if( str == "weak" )
  78. out = kiwi::strength::weak;
  79. else
  80. {
  81. PyErr_Format(
  82. PyExc_ValueError,
  83. "string strength must be 'required', 'strong', 'medium', "
  84. "or 'weak', not '%s'",
  85. str.c_str()
  86. );
  87. return false;
  88. }
  89. return true;
  90. }
  91. if( !convert_to_double( value, out ) )
  92. return false;
  93. return true;
  94. }
  95. inline bool
  96. convert_to_relational_op( PyObject* value, kiwi::RelationalOperator& out )
  97. {
  98. #if PY_MAJOR_VERSION >= 3
  99. if( !PyUnicode_Check( value ) )
  100. {
  101. PythonHelpers::py_expected_type_fail( value, "unicode" );
  102. return false;
  103. }
  104. #else
  105. if( !(PyString_Check( value ) | PyUnicode_Check( value ) ) )
  106. {
  107. PythonHelpers::py_expected_type_fail( value, "str or unicode" );
  108. return false;
  109. }
  110. #endif
  111. std::string str;
  112. if( !convert_pystr_to_str( value, str ) )
  113. return false;
  114. if( str == "==" )
  115. out = kiwi::OP_EQ;
  116. else if( str == "<=" )
  117. out = kiwi::OP_LE;
  118. else if( str == ">=" )
  119. out = kiwi::OP_GE;
  120. else
  121. {
  122. PyErr_Format(
  123. PyExc_ValueError,
  124. "relational operator must be '==', '<=', or '>=', not '%s'",
  125. str.c_str()
  126. );
  127. return false;
  128. }
  129. return true;
  130. }
  131. inline PyObject*
  132. make_terms( const std::map<PyObject*, double>& coeffs )
  133. {
  134. typedef std::map<PyObject*, double>::const_iterator iter_t;
  135. PythonHelpers::PyObjectPtr terms( PyTuple_New( coeffs.size() ) );
  136. if( !terms )
  137. return 0;
  138. Py_ssize_t size = PyTuple_GET_SIZE( terms.get() );
  139. for( Py_ssize_t i = 0; i < size; ++i ) // zero tuple for safe early return
  140. PyTuple_SET_ITEM( terms.get(), i, 0 );
  141. Py_ssize_t i = 0;
  142. iter_t it = coeffs.begin();
  143. iter_t end = coeffs.end();
  144. for( ; it != end; ++it, ++i )
  145. {
  146. PyObject* pyterm = PyType_GenericNew( &Term_Type, 0, 0 );
  147. if( !pyterm )
  148. return 0;
  149. Term* term = reinterpret_cast<Term*>( pyterm );
  150. term->variable = PythonHelpers::newref( it->first );
  151. term->coefficient = it->second;
  152. PyTuple_SET_ITEM( terms.get(), i, pyterm );
  153. }
  154. return terms.release();
  155. }
  156. inline PyObject*
  157. reduce_expression( PyObject* pyexpr ) // pyexpr must be an Expression
  158. {
  159. Expression* expr = reinterpret_cast<Expression*>( pyexpr );
  160. std::map<PyObject*, double> coeffs;
  161. Py_ssize_t size = PyTuple_GET_SIZE( expr->terms );
  162. for( Py_ssize_t i = 0; i < size; ++i )
  163. {
  164. PyObject* item = PyTuple_GET_ITEM( expr->terms, i );
  165. Term* term = reinterpret_cast<Term*>( item );
  166. coeffs[ term->variable ] += term->coefficient;
  167. }
  168. PythonHelpers::PyObjectPtr terms( make_terms( coeffs ) );
  169. if( !terms )
  170. return 0;
  171. PyObject* pynewexpr = PyType_GenericNew( &Expression_Type, 0, 0 );
  172. if( !pynewexpr )
  173. return 0;
  174. Expression* newexpr = reinterpret_cast<Expression*>( pynewexpr );
  175. newexpr->terms = terms.release();
  176. newexpr->constant = expr->constant;
  177. return pynewexpr;
  178. }
  179. inline kiwi::Expression
  180. convert_to_kiwi_expression( PyObject* pyexpr ) // pyexpr must be an Expression
  181. {
  182. Expression* expr = reinterpret_cast<Expression*>( pyexpr );
  183. std::vector<kiwi::Term> kterms;
  184. Py_ssize_t size = PyTuple_GET_SIZE( expr->terms );
  185. for( Py_ssize_t i = 0; i < size; ++i )
  186. {
  187. PyObject* item = PyTuple_GET_ITEM( expr->terms, i );
  188. Term* term = reinterpret_cast<Term*>( item );
  189. Variable* var = reinterpret_cast<Variable*>( term->variable );
  190. kterms.push_back( kiwi::Term( var->variable, term->coefficient ) );
  191. }
  192. return kiwi::Expression( kterms, expr->constant );
  193. }
  194. inline const char*
  195. pyop_str( int op )
  196. {
  197. switch( op )
  198. {
  199. case Py_LT:
  200. return "<";
  201. case Py_LE:
  202. return "<=";
  203. case Py_EQ:
  204. return "==";
  205. case Py_NE:
  206. return "!=";
  207. case Py_GT:
  208. return ">";
  209. case Py_GE:
  210. return ">=";
  211. default:
  212. return "";
  213. }
  214. }