dnn-layer-mathbinary.c 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. /*
  2. * Copyright (c) 2020
  3. *
  4. * This file is part of FFmpeg.
  5. *
  6. * FFmpeg is free software; you can redistribute it and/or
  7. * modify it under the terms of the GNU Lesser General Public
  8. * License as published by the Free Software Foundation; either
  9. * version 2.1 of the License, or (at your option) any later version.
  10. *
  11. * FFmpeg is distributed in the hope that it will be useful,
  12. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  13. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  14. * Lesser General Public License for more details.
  15. *
  16. * You should have received a copy of the GNU Lesser General Public
  17. * License along with FFmpeg; if not, write to the Free Software
  18. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  19. */
  20. #include <stdio.h>
  21. #include <string.h>
  22. #include <math.h>
  23. #include "libavfilter/dnn/dnn_backend_native_layer_mathbinary.h"
  24. #include "libavutil/avassert.h"
  25. #define EPSON 0.00005
  26. static float get_expected(float f1, float f2, DNNMathBinaryOperation op)
  27. {
  28. switch (op)
  29. {
  30. case DMBO_SUB:
  31. return f1 - f2;
  32. case DMBO_ADD:
  33. return f1 + f2;
  34. case DMBO_MUL:
  35. return f1 * f2;
  36. case DMBO_REALDIV:
  37. return f1 / f2;
  38. case DMBO_MINIMUM:
  39. return (f1 < f2) ? f1 : f2;
  40. case DMBO_FLOORMOD:
  41. return (float)((int)(f1) % (int)(f2));
  42. default:
  43. av_assert0(!"not supported yet");
  44. return 0.f;
  45. }
  46. }
  47. static int test_broadcast_input0(DNNMathBinaryOperation op)
  48. {
  49. DnnLayerMathBinaryParams params;
  50. DnnOperand operands[2];
  51. int32_t input_indexes[1];
  52. float input[1*1*2*3] = {
  53. -3, 2.5, 2, -2.1, 7.8, 100
  54. };
  55. float *output;
  56. params.bin_op = op;
  57. params.input0_broadcast = 1;
  58. params.input1_broadcast = 0;
  59. params.v = 7.28;
  60. operands[0].data = input;
  61. operands[0].dims[0] = 1;
  62. operands[0].dims[1] = 1;
  63. operands[0].dims[2] = 2;
  64. operands[0].dims[3] = 3;
  65. operands[1].data = NULL;
  66. input_indexes[0] = 0;
  67. ff_dnn_execute_layer_math_binary(operands, input_indexes, 1, &params, NULL);
  68. output = operands[1].data;
  69. for (int i = 0; i < sizeof(input) / sizeof(float); i++) {
  70. float expected_output = get_expected(params.v, input[i], op);
  71. if (fabs(output[i] - expected_output) > EPSON) {
  72. printf("op %d, at index %d, output: %f, expected_output: %f (%s:%d)\n",
  73. op, i, output[i], expected_output, __FILE__, __LINE__);
  74. av_freep(&output);
  75. return 1;
  76. }
  77. }
  78. av_freep(&output);
  79. return 0;
  80. }
  81. static int test_broadcast_input1(DNNMathBinaryOperation op)
  82. {
  83. DnnLayerMathBinaryParams params;
  84. DnnOperand operands[2];
  85. int32_t input_indexes[1];
  86. float input[1*1*2*3] = {
  87. -3, 2.5, 2, -2.1, 7.8, 100
  88. };
  89. float *output;
  90. params.bin_op = op;
  91. params.input0_broadcast = 0;
  92. params.input1_broadcast = 1;
  93. params.v = 7.28;
  94. operands[0].data = input;
  95. operands[0].dims[0] = 1;
  96. operands[0].dims[1] = 1;
  97. operands[0].dims[2] = 2;
  98. operands[0].dims[3] = 3;
  99. operands[1].data = NULL;
  100. input_indexes[0] = 0;
  101. ff_dnn_execute_layer_math_binary(operands, input_indexes, 1, &params, NULL);
  102. output = operands[1].data;
  103. for (int i = 0; i < sizeof(input) / sizeof(float); i++) {
  104. float expected_output = get_expected(input[i], params.v, op);
  105. if (fabs(output[i] - expected_output) > EPSON) {
  106. printf("op %d, at index %d, output: %f, expected_output: %f (%s:%d)\n",
  107. op, i, output[i], expected_output, __FILE__, __LINE__);
  108. av_freep(&output);
  109. return 1;
  110. }
  111. }
  112. av_freep(&output);
  113. return 0;
  114. }
  115. static int test_no_broadcast(DNNMathBinaryOperation op)
  116. {
  117. DnnLayerMathBinaryParams params;
  118. DnnOperand operands[3];
  119. int32_t input_indexes[2];
  120. float input0[1*1*2*3] = {
  121. -3, 2.5, 2, -2.1, 7.8, 100
  122. };
  123. float input1[1*1*2*3] = {
  124. -1, 2, 3, -21, 8, 10.0
  125. };
  126. float *output;
  127. params.bin_op = op;
  128. params.input0_broadcast = 0;
  129. params.input1_broadcast = 0;
  130. operands[0].data = input0;
  131. operands[0].dims[0] = 1;
  132. operands[0].dims[1] = 1;
  133. operands[0].dims[2] = 2;
  134. operands[0].dims[3] = 3;
  135. operands[1].data = input1;
  136. operands[1].dims[0] = 1;
  137. operands[1].dims[1] = 1;
  138. operands[1].dims[2] = 2;
  139. operands[1].dims[3] = 3;
  140. operands[2].data = NULL;
  141. input_indexes[0] = 0;
  142. input_indexes[1] = 1;
  143. ff_dnn_execute_layer_math_binary(operands, input_indexes, 2, &params, NULL);
  144. output = operands[2].data;
  145. for (int i = 0; i < sizeof(input0) / sizeof(float); i++) {
  146. float expected_output = get_expected(input0[i], input1[i], op);
  147. if (fabs(output[i] - expected_output) > EPSON) {
  148. printf("op %d, at index %d, output: %f, expected_output: %f (%s:%d)\n",
  149. op, i, output[i], expected_output, __FILE__, __LINE__);
  150. av_freep(&output);
  151. return 1;
  152. }
  153. }
  154. av_freep(&output);
  155. return 0;
  156. }
  157. static int test(DNNMathBinaryOperation op)
  158. {
  159. if (test_broadcast_input0(op))
  160. return 1;
  161. if (test_broadcast_input1(op))
  162. return 1;
  163. if (test_no_broadcast(op))
  164. return 1;
  165. return 0;
  166. }
  167. int main(int argc, char **argv)
  168. {
  169. if (test(DMBO_SUB))
  170. return 1;
  171. if (test(DMBO_ADD))
  172. return 1;
  173. if (test(DMBO_MUL))
  174. return 1;
  175. if (test(DMBO_REALDIV))
  176. return 1;
  177. if (test(DMBO_MINIMUM))
  178. return 1;
  179. if (test(DMBO_FLOORMOD))
  180. return 1;
  181. return 0;
  182. }