# Source code for orbit.py_linac.lattice_modifications.quad_overlap_modifications_lib

#!/usr/bin/env python

# --------------------------------------------------------------
# The functions for the lattice modifications to replace simple
# quads (class Quad) with objects that describe the fields of
# several quadrupoles with overlapping fields (class OverlappingQuadsNode).
# These combined fields can be chopped into parts to include the
# elements with zero length. All other non-zero-length elements
# inside the combined-field region can only be drifts.
# --------------------------------------------------------------

import math
import sys
import os
import time

# import from orbit Python utilities
from orbit.utils import orbitFinalize

from orbit.py_linac.lattice import Quad, Drift

from orbit.py_linac.lattice import OverlappingQuadsNode

from orbit.py_linac.lattice_modifications.rf_quad_overlap_modifications_lib import GetEngeFunction


# ---- a new (cut) drift or a part of an overlapping-fields region that is
# ---- shorter than this tolerance [m] is ignored
_DRIFT_LENGTH_TOLERANCE = 0.00000001


def Replace_Quads_to_OverlappingQuads_Nodes(
    accLattice, z_step, accSeq_Names=None, quad_Names=None, EngeFunctionFactory=GetEngeFunction
):
    """
    Replace Quad nodes by OverlappingQuadsNode nodes in the specified sequences.

    Only the acc. sequences named in accSeq_Names are modified. If quad_Names
    is empty (or None), all quads of these sequences are replaced; otherwise
    only the quads with the listed names. Quads with overlapping fields are
    grouped, and the space covered by each group is represented by one or
    several OverlappingQuadsNode instances (the region is cut at every
    zero-length node inside it). The drifts in which a group field starts or
    ends are shortened accordingly. Sequences without quads to replace are
    left untouched.

    Parameters:
        accLattice - the lattice; it is rebuilt and re-initialized at the end.
        z_step - the longitudinal step used during the tracking through the
            elements with overlapping fields.
        accSeq_Names - names of the acc. sequences to process.
        quad_Names - names of the quads to replace (empty means all quads).
        EngeFunctionFactory - callable quad -> field function describing the
            longitudinal field dependency; the returned object must provide
            getLimitsZ(). GetEngeFunction (Enge functions) by default. The user
            can supply his/her own factory.

    Stops through orbitFinalize(...) if a sequence cannot be found, if the
    field of a quad at the beginning/end of a sequence extends outside the
    sequence, or if a group field covers a non-drift element with L != 0.
    """
    if accSeq_Names is None:
        accSeq_Names = []
    if quad_Names is None:
        quad_Names = []
    node_pos_dict = accLattice.getNodePositionsDict()
    for accSeq_Name in accSeq_Names:
        accSeq = accLattice.getSequence(accSeq_Name)
        if accSeq is None:
            msg = "The Replace_Quads_to_OverlappingQuads_Nodes Python function. "
            msg += os.linesep
            msg += "Cannot find the acc. sequence with this name in the lattice!"
            msg += os.linesep
            msg += "accSeq name = " + accSeq_Name
            msg += os.linesep
            msg += "lattice name = " + accLattice.getName()
            msg += os.linesep
            orbitFinalize(msg)
        new_nodes = _build_seq_nodes_with_overlapping_quads(
            accLattice, accSeq, node_pos_dict, z_step, quad_Names, EngeFunctionFactory
        )
        if new_nodes is None:
            # ---- nothing to replace here; go on with the next sequence.
            # ---- (Returning here would skip the remaining sequences and the
            # ---- lattice rebuild for the sequences already modified.)
            continue
        # ---- replace all nodes in the AccSeq by the new set
        accSeq.removeAllNodes()
        for node in new_nodes:
            accSeq.addNode(node)
    # ---- new set of nodes for the lattice
    new_latt_nodes = []
    for accSeq in accLattice.getSequences():
        new_latt_nodes += accSeq.getNodes()
    accLattice.setNodes(new_latt_nodes)
    accLattice.initialize()


def _build_seq_nodes_with_overlapping_quads(accLattice, accSeq, node_pos_dict, z_step, quad_Names, EngeFunctionFactory):
    """
    Build the new, position-sorted list of nodes for accSeq.

    Returns None when the sequence has no nodes or no quads to replace.
    """
    # ---- just in case: if the nodes are not in the right order
    nodes = sorted(accSeq.getNodes(), key=lambda x: x.getPosition())
    if len(nodes) == 0:
        return None
    # ---- the start and end positions of the accSeq in the lattice
    accSeq_z_start = node_pos_dict[nodes[0]][0]
    accSeq_z_end = node_pos_dict[nodes[-1]][1]
    # ---- the node to index in the AccSeq dictionary
    node_to_index_in_seq_dict = {node: ind for (ind, node) in enumerate(nodes)}
    # ---- quads list for replacement
    quads = accLattice.getQuads(accSeq)
    if len(quad_Names) > 0:
        quads = [quad for quad in quads if quad.getName() in quad_Names]
    # ---- Enge Functions' dictionary with the usual quad nodes as keys
    enge_func_quad_dict = {quad: EngeFunctionFactory(quad) for quad in quads}
    # ---- groups of overlapping quads:
    # ---- [[quad0,quad1,..],pos_start,pos_end,ind_start,ind_end]
    quad_groups_and_ind_arr = Find_Groups_of_Quads(
        accLattice, accSeq, quads, enge_func_quad_dict, node_to_index_in_seq_dict
    )
    if len(quad_groups_and_ind_arr) == 0:
        return None
    _check_quad_groups(
        nodes, node_pos_dict, quad_groups_and_ind_arr, enge_func_quad_dict, accSeq_z_start, accSeq_z_end
    )
    # ---- The new nodes are collected in this fixed order and sorted by
    # ---- position at the end (the sort is stable, so the order breaks ties).
    # ---- 1st STEP - shortened drifts at the beginning and the end of groups
    new_nodes = _cut_group_boundary_drifts(nodes, node_pos_dict, quad_groups_and_ind_arr)
    # ---- 2nd STEP - nodes before the first group (old positions are kept)
    new_nodes += nodes[: quad_groups_and_ind_arr[0][3]]
    # ---- 3rd STEP - nodes after the last group (old positions are kept)
    new_nodes += nodes[quad_groups_and_ind_arr[-1][4] + 1 :]
    # ---- 4th STEP - nodes between the groups (old positions are kept)
    for (group0, group1) in zip(quad_groups_and_ind_arr[:-1], quad_groups_and_ind_arr[1:]):
        new_nodes += nodes[group0[4] + 1 : group1[3]]
    # ---- 5th STEP - OverlappingQuadsNode nodes covering all the quad groups
    for group in quad_groups_and_ind_arr:
        new_nodes += _make_overlapping_quads_nodes(accSeq, group, accSeq_z_start, enge_func_quad_dict, z_step)
    return sorted(new_nodes, key=lambda x: x.getPosition())


def _check_quad_groups(nodes, node_pos_dict, quad_groups_and_ind_arr, enge_func_quad_dict, accSeq_z_start, accSeq_z_end):
    """Stop (orbitFinalize) if a quad group cannot be represented by OverlappingQuadsNode nodes."""
    tolerance = _DRIFT_LENGTH_TOLERANCE
    n_groups = len(quad_groups_and_ind_arr)
    for group_ind, (quads_arr, pos_start, pos_end, ind_start, ind_end) in enumerate(quad_groups_and_ind_arr):
        is_edge_group = group_ind == 0 or group_ind == n_groups - 1
        for node_ind in range(ind_start, ind_end + 1):
            node = nodes[node_ind]
            if node in quads_arr:
                if is_edge_group:
                    # ---- the quad field must not extend outside the sequence
                    (z0_min, z0_max) = enge_func_quad_dict[node].getLimitsZ()
                    quad_pos = (node_pos_dict[node][0] + node_pos_dict[node][1]) / 2
                    delta_pos_quad_start = quad_pos + z0_min - accSeq_z_start
                    delta_pos_quad_end = quad_pos + z0_max - accSeq_z_end
                    if delta_pos_quad_start < (-tolerance) or delta_pos_quad_end > tolerance:
                        msg = "The Replace_Quads_to_OverlappingQuads_Nodes function. STOP. "
                        msg += os.linesep
                        msg += "We have quad at the beginning or the end of the sequence! quad = " + node.getName()
                        msg += os.linesep
                        msg += "quad field positions[m] [from,to] =" + str((delta_pos_quad_start, delta_pos_quad_end))
                        msg += os.linesep
                        msg += "The quad's distributed field will be outside the sequence!"
                        msg += os.linesep
                        msg += "The code cannot handle this situation at this moment."
                        msg += os.linesep
                        msg += "You have to change the EngeFunctionFactory to describe the right field function."
                        msg += os.linesep
                        orbitFinalize(msg)
                continue
            if node.getLength() != 0 and (not isinstance(node, Drift)):
                msg = "The Replace_Quads_to_OverlappingQuads_Nodes function. STOP. "
                msg += os.linesep
                msg += "The field of the group of quads covers the elements with L != 0 and different from Drift! "
                msg += os.linesep
                msg += "At this moment we do not know how to handle this situation."
                msg += os.linesep
                quad_names = "".join(quad.getName() + " " for quad in quads_arr)
                msg += "Quads = " + quad_names
                msg += os.linesep
                msg += "node=" + node.getName()
                msg += os.linesep
                msg += "Type = " + node.getType()
                orbitFinalize(msg)


def _make_drift(name, length, position):
    """Return a new Drift with the given name, length, and position."""
    drift = Drift(name)
    drift.setLength(length)
    drift.setPosition(position)
    return drift


def _cut_group_boundary_drifts(nodes, node_pos_dict, quad_groups_and_ind_arr):
    """
    Return shortened copies of the drifts in which the group fields start or end.

    Each copy keeps only the part of the drift outside the group field(s).
    A drift shared by the end of one group and the start of the next is
    shortened by both and is created only once (with the earlier group).
    """
    tolerance = _DRIFT_LENGTH_TOLERANCE
    n_groups = len(quad_groups_and_ind_arr)
    cut_drifts = []
    for group_ind, (quads_arr, pos_start, pos_end, ind_start, ind_end) in enumerate(quad_groups_and_ind_arr):
        # ---- drift where the group field starts: keep its upstream part
        node = nodes[ind_start]
        shared_with_prev_group = group_ind > 0 and quad_groups_and_ind_arr[group_ind - 1][4] == ind_start
        if isinstance(node, Drift) and not shared_with_prev_group:
            (node_pos_start, node_pos_end) = node_pos_dict[node]
            delta = node_pos_end - pos_start
            new_length = abs(node.getLength() - delta)
            if new_length > tolerance:
                cut_drifts.append(_make_drift(node.getName(), new_length, node.getPosition() - delta / 2))
        # ---- drift where the group field ends: keep its downstream part
        node = nodes[ind_end]
        if isinstance(node, Drift):
            (node_pos_start, node_pos_end) = node_pos_dict[node]
            delta = pos_end - node_pos_start
            new_length = abs(node.getLength() - delta)
            position = node.getPosition() + delta / 2
            if new_length > tolerance and group_ind < n_groups - 1:
                [quads_arr1, pos_start1, pos_end1, ind_start1, ind_end1] = quad_groups_and_ind_arr[group_ind + 1]
                if ind_end == ind_start1:
                    # ---- the next group starts in the same drift: cut it too
                    delta1 = node_pos_end - pos_start1
                    new_length = abs(new_length - delta1)
                    position -= delta1 / 2
            # ---- re-checked here so that a doubly cut drift is also dropped
            # ---- when it becomes shorter than the tolerance
            if new_length > tolerance:
                cut_drifts.append(_make_drift(node.getName(), new_length, position))
    return cut_drifts


def _make_overlapping_quads_nodes(accSeq, group, accSeq_z_start, enge_func_quad_dict, z_step):
    """
    Return OverlappingQuadsNode nodes covering one quad group plus the zero-length nodes inside it.

    The group region is cut at every zero-length node (including zero-length
    body children); each part gets its own OverlappingQuadsNode with all the
    quads of the group as field sources. Parts shorter than the tolerance are
    skipped. Positions are in the sequence coordinates.
    """
    tolerance = _DRIFT_LENGTH_TOLERANCE
    [quads_arr, group_pos_start, group_pos_end, ind_start, ind_end] = group
    (quads_in_range, zero_length_nodes) = Get_quads_zeroLengthNodes_in_range(accSeq, ind_start, ind_end)
    # NOTE(review): zero-length nodes returned for the range are assumed to lie
    # inside [group start, group end]; one outside would give a part with a
    # negative length - confirm.
    node_pos_end = group_pos_end - accSeq_z_start
    pos_start = group_pos_start - accSeq_z_start
    n_zero = len(zero_length_nodes)
    new_nodes = []
    ovrlp_count = 0
    for ind in range(n_zero + 1):
        zero_node = zero_length_nodes[ind] if ind < n_zero else None
        pos_end = node_pos_end if zero_node is None else zero_node.getPosition()
        length = pos_end - pos_start
        if abs(length) >= tolerance:
            node = OverlappingQuadsNode()
            name = quads_arr[0].getName() + ":group:" + str(ovrlp_count + 1) + ":" + node.getType()
            node.setName(name)
            node.setLength(length)
            node.setPosition((pos_end + pos_start) / 2)
            for quad in quads_arr:
                # ---- quad position relative to the start of this part
                node.addQuad(quad, enge_func_quad_dict[quad], quad.getPosition() - pos_start)
            node.setZ_Step(z_step)
            node.setnParts(int(length / z_step) + 1)
            new_nodes.append(node)
            ovrlp_count += 1
        if zero_node is not None:
            new_nodes.append(zero_node)
        pos_start = pos_end
    return new_nodes
def Find_Groups_of_Quads(accLattice, accSeq, quads, enge_func_quad_dict, node_to_index_in_seq_dict):
    """
    Find the groups of quads whose fields overlap each other.

    The field of a quad spans [center + z_min, center + z_max] in the lattice
    coordinates, where (z_min, z_max) = enge_func_quad_dict[quad].getLimitsZ().
    The quads are processed in the order of their center positions. A quad
    joins the current group when its field starts at or before the end of the
    group's combined field; otherwise it starts a new group.

    Returns a list with one entry per group:
        [[quad0, quad1, ...], pos_start, pos_end, ind_start, ind_end]
    where pos_start/pos_end are the lattice positions where the combined field
    starts/ends, and ind_start/ind_end are the indices (in accSeq.getNodes()) of
    the nodes containing these positions. An empty quads list gives [].

    node_to_index_in_seq_dict is not used; it is kept for compatibility.
    """
    nodes = accSeq.getNodes()
    node_pos_dict = accLattice.getNodePositionsDict()
    # ---- (center, quad, field_start, field_end) for each quad
    quad_fields = []
    for quad in quads:
        (pos_start, pos_end) = node_pos_dict[quad]
        (z_min, z_max) = enge_func_quad_dict[quad].getLimitsZ()
        pos_center = (pos_start + pos_end) / 2
        quad_fields.append((pos_center, quad, pos_center + z_min, pos_center + z_max))
    # ---- just in case the quads are not ordered by position (stable sort)
    quad_fields.sort(key=lambda item: item[0])
    # ---- merge into groups [[quads], field_start, field_end]; the bounds are
    # ---- min/max over all members, not only the first/last quad
    groups = []
    for (pos_center, quad, field_start, field_end) in quad_fields:
        if len(groups) > 0 and field_start <= groups[-1][2]:
            group = groups[-1]
            group[0].append(quad)
            group[1] = min(group[1], field_start)
            group[2] = max(group[2], field_end)
        else:
            groups.append([[quad], field_start, field_end])
    quad_groups_and_ind_arr = []
    for (quads_arr, pos_start, pos_end) in groups:
        ind_start = GetNodeInAccSeqForPosition(accLattice, nodes, pos_start)[1]
        ind_end = GetNodeInAccSeqForPosition(accLattice, nodes, pos_end)[1]
        quad_groups_and_ind_arr.append([quads_arr, pos_start, pos_end, ind_start, ind_end])
    return quad_groups_and_ind_arr


def GetNodeInAccSeqForPosition(accLattice, nodes, z):
    """
    Return (node, index, posBefore, posAfter) for the node covering the lattice position z.

    It is a local convenience function. It differs from the AccLattice method
    getNodeForPosition(z) because it searches (by bisection) only in the given
    nodes list, which must be non-empty and ordered by position. If z is before
    the first node or after the last one, that first/last node is returned.
    (posBefore, posAfter) are the node's lattice start/end positions.
    """
    node_pos_dict = accLattice.getNodePositionsDict()
    index0 = 0
    index1 = len(nodes) - 1
    (posBefore0, posAfter0) = node_pos_dict[nodes[index0]]
    (posBefore1, posAfter1) = node_pos_dict[nodes[index1]]
    index = 0
    if z <= posBefore0:
        return (nodes[index0], index0, posBefore0, posAfter0)
    if z >= posAfter1:
        return (nodes[index1], index1, posBefore1, posAfter1)
    while index0 != index1:
        index = (index0 + index1) // 2
        (posBefore, posAfter) = node_pos_dict[nodes[index]]
        if z < posBefore:
            index1 = index
            (posBefore1, posAfter1) = node_pos_dict[nodes[index1]]
        elif z > posAfter:
            index0 = index
            (posBefore0, posAfter0) = node_pos_dict[nodes[index0]]
        else:
            break
        if z >= posBefore0 and z <= posAfter0:
            index = index0
            break
        if z >= posBefore1 and z <= posAfter1:
            index = index1
            break
        if index1 - index0 <= 1:
            # ---- z lies in a gap between two neighbouring nodes: take the
            # ---- nearer one; otherwise index0/index1 stop changing and the
            # ---- loop never ends
            index = index0 if (z - posAfter0) <= (posBefore1 - z) else index1
            break
    node = nodes[index]
    (posBefore, posAfter) = node_pos_dict[node]
    return (node, index, posBefore, posAfter)


def Get_quads_zeroLengthNodes_in_range(accSeq, node_ind_start, node_ind_end):
    """
    Return (quads, zero_length_nodes) for the nodes of accSeq with indices in [node_ind_start, node_ind_end].

    zero_length_nodes contains the zero-length nodes of the range and the
    zero-length body children of all nodes in the range, sorted by position.
    Stops through orbitFinalize(...) if an element in the range has non-zero
    length and is neither a Quad nor a Drift.
    """
    nodes = accSeq.getNodes()
    zero_length_nodes = []
    child_nodes = []
    quads = []
    for node_ind in range(node_ind_start, node_ind_end + 1):
        node = nodes[node_ind]
        child_nodes += [child for child in node.getBodyChildren() if child.getLength() == 0.0]
        if node.getLength() == 0.0:
            zero_length_nodes.append(node)
        elif isinstance(node, Quad):
            quads.append(node)
        elif not isinstance(node, Drift):
            msg = "The Replace_Quads_to_OverlappingQuads_Nodes function. "
            msg += "This Acc. Sequence has an element that "
            msg += os.linesep
            msg += "1. has non-zero length"
            msg += os.linesep
            msg += "2. not a quad"
            msg += os.linesep
            msg += "3. not a drift"
            msg += os.linesep
            msg += "This function does not know how to handle this element!"
            msg += os.linesep
            msg += "Acc Sequence =" + accSeq.getName()
            msg += os.linesep
            msg += "Acc element =" + node.getName()
            msg += os.linesep
            msg += "Acc element type =" + node.getType()
            msg += os.linesep
            orbitFinalize(msg)
    zero_length_nodes = sorted(zero_length_nodes + child_nodes, key=lambda x: x.getPosition())
    return (quads, zero_length_nodes)