gstlal  1.4.1
sumsquares_test_01.py
1 #!/usr/bin/env python
2 # Copyright (C) 2014 Kipp Cannon
3 #
4 # This program is free software; you can redistribute it and/or modify it
5 # under the terms of the GNU General Public License as published by the
6 # Free Software Foundation; either version 2 of the License, or (at your
7 # option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful, but
10 # WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
12 # Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License along
15 # with this program; if not, write to the Free Software Foundation, Inc.,
16 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
17 
18 #
19 # =============================================================================
20 #
21 # Preamble
22 #
23 # =============================================================================
24 #
25 
26 
27 import numpy
28 import sys
29 from gstlal import pipeparts
30 import test_common
31 import cmp_nxydumps
32 
33 
34 #
35 # =============================================================================
36 #
37 # Pipelines
38 #
39 # =============================================================================
40 #
41 
42 
43 #
44 # is the element an identity transform when given 1 channel of 1s and no
45 # weights array?
46 #
47 
48 
def sumsquares_test_01(pipeline, name, width):
    """Build a pipeline testing that sumsquares is an identity on a stream of 1s.

    A single-channel square wave of zero frequency (i.e. a constant stream of
    1s), with gaps controlled by gap_frequency / gap_threshold, is split by a
    tee.  One branch passes through the sumsquares element (no weights) and a
    timestamp checker and is dumped to "<name>_out.dump".  The other branch
    is dumped unmodified to "<name>_in.dump".  The source's control signal is
    dumped to "<name>_control.dump".

    pipeline -- the pipeline to populate.
    name -- prefix used for all dump file names.
    width -- sample width handed to the test source (presumably bits; the
             driver passes 64 and 32).

    Returns the populated pipeline.
    """
    # source settings; the test is expected to pass for other values too
    source_settings = dict(
        rate = 2048,            # Hz
        gap_frequency = 13.0,   # Hz
        gap_threshold = 0.8,    # fraction of 1
        buffer_length = 1.0,    # seconds
        test_duration = 10.0,   # seconds
    )

    # square wave with 0 frequency = stream of 1s
    source = test_common.gapped_test_src(
        pipeline,
        width = width,
        wave = 1,
        freq = 0,
        channels = 1,
        control_dump_filename = "%s_control.dump" % name,
        **source_settings
    )
    splitter = pipeparts.mktee(pipeline, source)

    # branch under test: sumsquares followed by a timestamp sanity check
    checked = pipeparts.mkchecktimestamps(pipeline, pipeparts.mksumsquares(pipeline, splitter))
    pipeparts.mknxydumpsink(pipeline, pipeparts.mkqueue(pipeline, checked), "%s_out.dump" % name)

    # reference branch: the untouched input
    pipeparts.mknxydumpsink(pipeline, pipeparts.mkqueue(pipeline, splitter), "%s_in.dump" % name)

    return pipeline
77 
78 
79 #
80 # test the transformation of a specific buffer with a specific weights vector
81 #
82 
83 
def sumsquares_test_02(name, dtype, samples, channels_in, sample_fuzz = cmp_nxydumps.default_sample_fuzz):
    """Check sumsquares on one random buffer against a numpy reference.

    A random (samples, channels_in) input array of the requested dtype and a
    random weights vector of length channels_in are generated from a fixed
    seed.  The input is pushed through the sumsquares element via
    test_common.transform_arrays.  The result is compared with
    sum over channels of (weight * input)**2, reshaped to (samples, 1).

    name -- test name handed to test_common.transform_arrays.
    dtype -- numpy dtype (name) of the input samples.
    samples -- number of samples (rows) in the input buffer.
    channels_in -- number of input channels (columns).
    sample_fuzz -- maximum allowed absolute deviation per output sample.

    Raises ValueError if the output shape differs from the reference, or if
    any output sample deviates by more than sample_fuzz (NaN deviations count
    as failures).
    """
    # Private generator: yields exactly the values numpy.random.seed(0)
    # followed by numpy.random.random() would, without clobbering the
    # process-wide RNG state.
    rng = numpy.random.RandomState(0)
    input_array = rng.random_sample((samples, channels_in)).astype(dtype)
    # element always ingests weights matrix as double-precision floats
    weights = rng.random_sample((channels_in,)).astype("float64")

    output_reference = ((weights.astype(dtype) * input_array) ** 2).sum(axis = 1)
    output_reference.shape = output_reference.shape + (1,)

    output_array, = test_common.transform_arrays([input_array], pipeparts.mksumsquares, name, weights = weights)

    # Without this check, numpy broadcasting could let a wrongly-shaped
    # output (e.g. a single row) be compared element-wise and pass.
    if output_array.shape != output_reference.shape:
        raise ValueError("incorrect output shape: expected %s, got %s" % (output_reference.shape, output_array.shape))

    residual = abs(output_array - output_reference)
    # "not <=" rather than ">" so that NaN residuals (which compare False
    # against everything) are reported as failures instead of passing.
    if not (residual <= sample_fuzz).all():
        raise ValueError("incorrect output: expected %s, got %s\ndifference = %s" % (output_reference, output_array, residual))
98 
99 
100 #
101 # =============================================================================
102 #
103 # Main
104 #
105 # =============================================================================
106 #
107 
108 
# identity test at two sample widths; each run's output dump must match its
# input dump, including the exact placement of gaps
for test_name, sample_width in (("sumsquares_test_01a", 64), ("sumsquares_test_01b", 32)):
    test_common.build_and_run(sumsquares_test_01, test_name, width = sample_width)
    cmp_nxydumps.compare("%s_in.dump" % test_name, "%s_out.dump" % test_name, flags = cmp_nxydumps.COMPARE_FLAGS_EXACT_GAPS)


# weighted sum-of-squares test; single precision gets a looser tolerance
double_fuzz = cmp_nxydumps.default_sample_fuzz
sumsquares_test_02("sumsquares_test_02a", "float64", samples = 6, channels_in = 4, sample_fuzz = double_fuzz)
sumsquares_test_02("sumsquares_test_02b", "float32", samples = 6, channels_in = 4, sample_fuzz = double_fuzz**.5)
def compare(filename1, filename2, args, kwargs)
def build_and_run(pipelinefunc, name, segment=None, pipelinefunc_kwargs)
Definition: test_common.py:118
def gapped_test_src(pipeline, buffer_length=1.0, rate=2048, width=64, channels=1, test_duration=10.0, wave=5, freq=0, gap_frequency=None, gap_threshold=None, control_dump_filename=None, tags=None, is_live=False, verbose=True)
Definition: test_common.py:95
def transform_arrays(input_arrays, elemfunc, name, rate=1, elemfunc_kwargs)
Definition: test_common.py:141