13 #ifndef MLPACK_METHODS_RL_SUMTREE_HPP
14 #define MLPACK_METHODS_RL_SUMTREE_HPP
46 SumTree(
const size_t capacity) : capacity(capacity)
48 element = std::vector<T>(2 * capacity);
57 void Set(
size_t idx,
const T value)
64 element[idx] = element[2 * idx] + element[2 * idx + 1];
75 void BatchUpdate(
const arma::ucolvec& indices,
const arma::Col<T>& data)
77 for (
size_t i = 0; i < indices.n_rows; ++i)
79 element[indices[i] + capacity] = data[i];
82 for (
size_t i = capacity - 1; i > 0; i--)
84 element[i] = element[2 * i] + element[2 * i + 1];
111 const size_t nodeStart,
112 const size_t nodeEnd)
114 if (start == nodeStart && end == nodeEnd)
116 return element[node];
118 size_t mid = (nodeStart + nodeEnd) / 2;
121 return SumHelper(start, end, 2 * node, nodeStart, mid);
125 if (mid + 1 <= start)
127 return SumHelper(start, end, 2 * node + 1, mid + 1 , nodeEnd);
131 return SumHelper(start, mid, 2 * node, nodeStart, mid) +
132 SumHelper(mid + 1, end, 2 * node + 1, mid + 1 , nodeEnd);
143 T
Sum(
const size_t start,
size_t end)
146 return SumHelper(start, end, 1, 0, capacity - 1);
154 return Sum(0, capacity);
166 while (idx < capacity)
168 if (element[2 * idx] > mass)
174 mass -= element[2 * idx];
178 return idx - capacity;
186 std::vector<T> element;