wrapper.cpp 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. #include "wrapper.h"
  2. #include <util/datetime/cputimer.h>
  3. #include <util/stream/buffered.h>
  4. #include <util/stream/buffer.h>
  5. #include <util/stream/format.h>
  6. #include <util/stream/input.h>
  7. #include <util/stream/mem.h>
  8. #include <util/stream/output.h>
  9. #include <util/system/sys_alloc.h>
  10. namespace {
  11. class TLuaCountLimit {
  12. public:
  13. TLuaCountLimit(lua_State* state, int count)
  14. : State(state)
  15. {
  16. lua_sethook(State, LuaHookCallback, LUA_MASKCOUNT, count);
  17. }
  18. ~TLuaCountLimit() {
  19. lua_sethook(State, LuaHookCallback, 0, 0);
  20. }
  21. static void LuaHookCallback(lua_State* L, lua_Debug*) {
  22. luaL_error(L, "Lua instruction count limit exceeded");
  23. }
  24. private:
  25. lua_State* State;
  26. }; // class TLuaCountLimit
  27. class TLuaTimeLimit {
  28. public:
  29. TLuaTimeLimit(lua_State* state, TDuration limit, int count)
  30. : State(state)
  31. , Limit(limit)
  32. {
  33. lua_pushlightuserdata(State, (void*)LuaHookCallback); //key
  34. lua_pushlightuserdata(State, (void*)this); //value
  35. lua_settable(State, LUA_REGISTRYINDEX);
  36. lua_sethook(State, LuaHookCallback, LUA_MASKCOUNT, count);
  37. }
  38. ~TLuaTimeLimit() {
  39. lua_sethook(State, LuaHookCallback, 0, 0);
  40. }
  41. bool Exceeded() {
  42. return Timer.Get() > Limit;
  43. }
  44. static void LuaHookCallback(lua_State* L, lua_Debug*) {
  45. lua_pushlightuserdata(L, (void*)LuaHookCallback);
  46. lua_gettable(L, LUA_REGISTRYINDEX);
  47. TLuaTimeLimit* t = static_cast<TLuaTimeLimit*>(lua_touserdata(L, -1));
  48. lua_pop(L, 1);
  49. if (t->Exceeded()) {
  50. luaL_error(L, "time limit exceeded");
  51. }
  52. }
  53. private:
  54. lua_State* State;
  55. const TDuration Limit;
  56. TSimpleTimer Timer;
  57. }; // class TLuaTimeLimit
  58. class TLuaReader {
  59. public:
  60. TLuaReader(IZeroCopyInput* in)
  61. : In_(in)
  62. {
  63. }
  64. inline void Load(lua_State* state, const char* name) {
  65. if (lua_load(state, ReadCallback, this, name
  66. #if LUA_VERSION_NUM > 501
  67. ,
  68. nullptr
  69. #endif
  70. ))
  71. {
  72. ythrow TLuaStateHolder::TError() << "can not parse lua chunk " << name << ": " << lua_tostring(state, -1);
  73. }
  74. }
  75. static const char* ReadCallback(lua_State*, void* data, size_t* size) {
  76. return ((TLuaReader*)(data))->Read(size);
  77. }
  78. private:
  79. inline const char* Read(size_t* readed) {
  80. const char* ret;
  81. if (*readed = In_->Next(&ret)) {
  82. return ret;
  83. }
  84. return nullptr;
  85. }
  86. private:
  87. IZeroCopyInput* In_;
  88. }; // class TLuaReader
  89. class TLuaWriter {
  90. public:
  91. TLuaWriter(IOutputStream* out)
  92. : Out_(out)
  93. {
  94. }
  95. inline void Dump(lua_State* state) {
  96. if (lua_dump(state, WriteCallback, this)) {
  97. ythrow TLuaStateHolder::TError() << "can not dump lua state: " << lua_tostring(state, -1);
  98. }
  99. }
  100. static int WriteCallback(lua_State*, const void* data, size_t size, void* user) {
  101. return ((TLuaWriter*)(user))->Write(data, size);
  102. }
  103. private:
  104. inline int Write(const void* data, size_t size) {
  105. Out_->Write(static_cast<const char*>(data), size);
  106. return 0;
  107. }
  108. private:
  109. IOutputStream* Out_;
  110. }; // class TLuaWriter
  111. } //namespace
  112. void TLuaStateHolder::Load(IInputStream* in, TZtStringBuf name) {
  113. TBufferedInput wi(in, 8192);
  114. return TLuaReader(&wi).Load(State_, name.c_str());
  115. }
  116. void TLuaStateHolder::Dump(IOutputStream* out) {
  117. return TLuaWriter(out).Dump(State_);
  118. }
  119. void TLuaStateHolder::DumpStack(IOutputStream* out) {
  120. for (int i = lua_gettop(State_) * -1; i < 0; ++i) {
  121. *out << i << " is " << lua_typename(State_, lua_type(State_, i)) << " (";
  122. if (is_number(i)) {
  123. *out << to_number<long long>(i);
  124. } else if (is_string(i)) {
  125. *out << to_string(i);
  126. } else {
  127. *out << Hex((uintptr_t)lua_topointer(State_, i), HF_ADDX);
  128. }
  129. *out << ')' << Endl;
  130. }
  131. }
  132. void* TLuaStateHolder::Alloc(void* ud, void* ptr, size_t /*osize*/, size_t nsize) {
  133. (void)ud;
  134. if (nsize == 0) {
  135. y_deallocate(ptr);
  136. return nullptr;
  137. }
  138. return y_reallocate(ptr, nsize);
  139. }
  140. void* TLuaStateHolder::AllocLimit(void* ud, void* ptr, size_t osize, size_t nsize) {
  141. TLuaStateHolder& state = *static_cast<TLuaStateHolder*>(ud);
  142. if (nsize == 0) {
  143. y_deallocate(ptr);
  144. state.AllocFree += osize;
  145. return nullptr;
  146. }
  147. if (state.AllocFree + osize < nsize) {
  148. return nullptr;
  149. }
  150. ptr = y_reallocate(ptr, nsize);
  151. if (ptr) {
  152. state.AllocFree += osize;
  153. state.AllocFree -= nsize;
  154. }
  155. return ptr;
  156. }
  157. void TLuaStateHolder::call(int args, int rets, int count) {
  158. TLuaCountLimit limit(State_, count);
  159. return call(args, rets);
  160. }
  161. void TLuaStateHolder::call(int args, int rets, TDuration time_limit, int count) {
  162. TLuaTimeLimit limit(State_, time_limit, count);
  163. return call(args, rets);
  164. }
  165. template <>
  166. void Out<NLua::TStackDumper>(IOutputStream& out, const NLua::TStackDumper& sd) {
  167. sd.State.DumpStack(&out);
  168. }
  169. template <>
  170. void Out<NLua::TMarkedStackDumper>(IOutputStream& out, const NLua::TMarkedStackDumper& sd) {
  171. out << sd.Mark << Endl;
  172. sd.State.DumpStack(&out);
  173. out << sd.Mark << Endl;
  174. }
  175. namespace NLua {
  176. TBuffer& Compile(TStringBuf script, TBuffer& buffer) {
  177. TMemoryInput input(script.data(), script.size());
  178. TLuaStateHolder state;
  179. state.Load(&input, "main");
  180. TBufferOutput out(buffer);
  181. state.Dump(&out);
  182. return buffer;
  183. }
  184. }