# Source code for madcad.constraints

# This file is part of pymadcad,  distributed under license LGPL v3

''' This module defines the constraint definitions and the solver tools

	Constraints can be any object referencing `variables` and implementing the following signature to guide the solver in the resolution:
		
		class SomeConstraint:
			
			# name the attributes referencing the solver variables to change
			slvvars = 'some_primitive', 'some_point'
			# function returning the squared error in the constraint
			
			#  for a coincident constraint for instance it's the squared distance
			#  the error must be a continuous function of the parameters, and it's squared for numeric stability reasons.
			#  the solver internally uses an iterative approach to optimize the sum of all fit functions.
			def fit(self):
				...
'''

from collections import Counter
import numpy as np
from scipy.optimize import least_squares
from .nprint import nprint
from .mathutils import *
from . import primitives
from . import displays, text, settings
from . import scheme
import array


class SolveError(Exception):
	''' Exception raised by the solvers of this module when no solution is reached '''

class Constraint(object):
	''' Base class for constraints whose members are declared in `__slots__`
	
		Subclasses are expected to define `__slots__`: the constructor reads it to know
		which attributes to fill.
	'''
	def __init__(self, *args, **kwargs):
		''' Assign the positional arguments to the slots, in declaration order
		
			Slots that receive no positional argument are set to None, and positional 
			arguments in excess are ignored.
			Keyword arguments are then assigned by name, overriding the positional values.
		'''
		positional = iter(args)
		for slot in self.__slots__:
			# an exhausted iterator gives the None default for the remaining slots
			setattr(self, slot, next(positional, None))
		for key in kwargs:
			setattr(self, key, kwargs[key])
	#def fitgrad(self):
		#varset = Varset()
		#for name in self.primitives:
			#varset.register(getattr(self, name))
		#return Derived(varset.state(), varset.grad(), varset.vars)

def isconstraint(obj):
	''' Return True if obj match the constraint signature 
	
		The signature is checked by duck typing only: the object must expose 
		both a `fit` and a `slvvars` attribute.
	'''
	return hasattr(obj, 'fit') and hasattr(obj, 'slvvars')
class Tangent(Constraint):
	''' Makes two curves tangent in the given point
	
		The point moves as well as curves.
		
		The curves primitives must have a member ``slv_tangent(p)`` returning the tangent vector at the nearest position to `p`
	'''
	__slots__ = 'c1', 'c2', 'p', 'size'
	slvvars = 'c1', 'c2', 'p'
	
	def fit(self):
		''' Residuals tuple: squared norm of the cross product of both tangents at `p` '''
		return length2(cross(self.c1.slv_tangent(self.p), self.c2.slv_tangent(self.p))),
	
	#def display(self, scene):
		#return displays.TangentDisplay(scene, (self.p, self.c2.slv_tangent(self.p)), self.size)
class Distance(Constraint):
	''' Makes two points distant of the fixed given distance 
	
		`p1` and `p2` can each be a `vec3` or an object with `origin` and `direction` members.
		`along` is only used when both are `vec3`: it is then either a `vec3` or an object 
		with a `direction` member, and the distance is measured along it.
	'''
	__slots__ = 'p1', 'p2', 'd', 'along', 'location'
	slvvars = 'p1', 'p2'
	
	def fit(self):
		''' Residuals tuple of the constraint, the content depends on the types of `p1` and `p2` '''
		if isinstance(self.p1, vec3) and isinstance(self.p2, vec3):
			if self.along:
				# the projection direction and the distance are members of the constraint, not free names
				if isinstance(self.along, vec3):	a = self.along
				else:								a = self.along.direction
				return (dot(self.p1-self.p2, a) - self.d) ** 2,
			else:
				return (distance(self.p1, self.p2) - self.d) **2,
		elif isinstance(self.p1, vec3):
			return (length(noproject(self.p2.origin-self.p1, self.p2.direction)) - self.d) **2,
		elif isinstance(self.p2, vec3):
			return (length(noproject(self.p1.origin-self.p2, self.p1.direction)) - self.d) **2,
		else:
			d1 = self.p1.direction
			d2 = self.p2.direction
			# make both directions point toward the same side before summing them
			if dot(d1,d2) < 0:	d2 = -d2
			# first residual for the parallelism, second for the distance
			return length2(cross(d1,d2)), (length(noproject(self.p1.origin-self.p2.origin, d1+d2)) - self.d) **2
	
	#def fitgrad(self):
		#return derived.compose(dot, Derived(self.p1), Derived(self.p2))
	
	def display(self, scene):
		''' Display of a distance annotation in the scene, labeled with the target distance and the current error '''
		# NOTE(review): when both p1 and p2 have a direction, fit()[0] is the parallelism residual and not the distance one - confirm this is the intended label
		text = '{:.5g}\n{:+.1g}'.format(self.d, sqrt(self.fit()[0]))
		if isinstance(self.p1, vec3) and isinstance(self.p2, vec3):
			return scene.display(scheme.note_distance(
						self.p1, 
						self.p2, 
						project=self.along, 
						text=text,
						))
		elif isinstance(self.p1, vec3):
			return scene.display(scheme.note_distance(
						self.p1, 
						self.p1 + noproject(self.p2.origin-self.p1, self.p2.direction), 
						text=text,
						))
		elif isinstance(self.p2, vec3):
			return scene.display(scheme.note_distance(
						self.p2, 
						self.p2 + noproject(self.p1.origin-self.p2, self.p1.direction), 
						text=text,
						))
		else:
			p = mix(self.p1.origin, self.p2.origin, 0.5)
			d1 = self.p1.direction
			d2 = self.p2.direction
			if dot(d1,d2) < 0:	d2 = -d2
			d = d1 + d2
			p1 = self.p1.origin
			p2 = self.p2.origin
			return scene.display(scheme.note_distance(
						p + noproject(p1-p, d), 
						p + noproject(p2-p, d),
						text=text,
						))
class Angle(Constraint):
	''' Gets two segments with the given fixed angle between them 
	
		`s1` and `s2` must have `origin` and `direction` members.
	'''
	__slots__ = 's1', 's2', 'angle'
	slvvars = 's1', 's2'
	
	def fit(self):
		''' Residuals tuple: squared difference between the current angle and the target angle '''
		d1 = self.s1.direction
		d2 = self.s2.direction
		a = atan2(length(cross(d1,d2)), dot(d1,d2))
		return (a - self.angle)**2,
	
	def display(self, scene):
		''' Display of an angle annotation in the scene, labeled in degrees with the target angle and the current error '''
		return scene.display(scheme.note_angle(
					(self.s1.origin, -self.s1.direction), 
					(self.s2.origin, -self.s2.direction),
					text='{:.5g}°\n{:+.1g}'.format(degrees(self.angle), degrees(sqrt(self.fit()[0]))),
					))
class Parallel(Constraint):
	''' Strict equivalent of Angle(s1,s2,0) 
	
		`s1` and `s2` must have a `direction` member.
	'''
	__slots__ = 's1', 's2'
	slvvars = __slots__
	
	def fit(self):
		''' Residuals tuple: squared norm of the cross product of both directions '''
		d1 = self.s1.direction
		d2 = self.s2.direction
		return length2(cross(d1,d2)),
class Radius(Constraint):
	''' Gets the given Arc with the given fixed radius 
	
		Note: Only ArcCentered are supported yet.
	'''
	__slots__ = 'arc', 'radius', 'location'
	slvvars = 'arc',
	
	def fit(self):
		''' Residuals tuple: squared difference between the arc radius and the target radius '''
		return (self.arc.radius - self.radius) **2,
	
	def display(self, scene):
		''' Display of a radius annotation in the scene, labeled with the current arc radius and its gap to the target '''
		r = self.arc.radius
		center, z = self.arc.axis
		# any direction from dirbase(z) is used to place the note at distance r from the center
		x = dirbase(z)[0]
		return scene.display(scheme.note_leading(
					center+x*r, 
					r*x, 
					text='R{:.5g}\n{:+.1g}'.format(r, self.radius - self.arc.radius),
					))
class OnPlane(Constraint):
	''' Puts the given points on the fixed plane given by its normal axis 
	
		`axis` is indexed as `(origin, normal)` and `pts` is an iterable of points.
	'''
	__slots__ = 'axis', 'pts'
	
	def slvvars(self):
		''' The variables are the points themselves (callable form of `slvvars`) '''
		return self.pts
	
	def fit(self):
		''' Generator of residuals: one squared offset along the normal per point '''
		for p in self.pts:
			yield dot(p-self.axis[0], self.axis[1]) **2
class PointOn(Constraint):
	''' Puts the given point on the curve.
	
		The curve primitive must have a member ``slv_nearest(p) -> vec3`` returning the closest point to p on the curve.
	'''
	__slots__ = 'point', 'curve'
	slvvars = 'point', 'curve'
	
	def fit(self):
		''' Residuals tuple: squared distance between the point and its nearest point on the curve '''
		return distance2(self.curve.slv_nearest(self.point), self.point),
def solve(constraints, fixed=(), *args, **kwargs):
	''' Short hand to use the class Problem 
	
		`constraints` and `fixed` are passed to the `Problem` constructor, every other 
		argument is forwarded to `Problem.solve`, whose result is returned.
	'''
	return Problem(constraints, fixed).solve(*args, **kwargs)
class Problem:
	''' Class holding the data of a problem solving process.
	
		It is intended to be instantiated for each different problem to solve. An instance is used multiple times only when we want to solve on top of the previous results, using exactly the same problem definition (constraints and variables).
		
		Therefore the solver protocol is the following:
		
			- constraints define the problem
			- each constraint refers to the variables it applies on.
			  Constraints have the method `fit()` and a member `slvvars` that can be
				
				1. An iterable of names of variable members in the constraint object
				2. A function returning an iterable of the actual variables objects (that therefore must be referenced objects and not primitive types)
				
			- each variable object can redirect to other variable objects if it implements such a member `slvvars`
			- primitives can also be constraints on their variables, thus they must have a method `fit()`  (but no member 'primitives' here)
			- primitives can implement the optional solver methods for some constraints, such as `slv_tangent`
		
		Attributes:
			constraints:  set of the bound `fit` methods to evaluate
			slvvars:      dictionary of the variables, that are either
			
				- a mutable sequence of numbers, stored under the key `id(sequence)`
				- a couple `(obj, name)` designating the scalar attribute `obj.name`, stored under the key `(id(obj), name)`
				
			dim:          number of scalar unknowns, computed once by the constructor
	'''
	def __init__(self, constraints, fixed=()):
		''' Collect the variables of `constraints`, then remove those reachable from the objects in `fixed` '''
		self.constraints = set()
		self.slvvars = {}
		self.dim = 0
		for cst in constraints:
			# cst can contains objects that are not constraints
			if hasattr(cst, 'fit'):
				self.register(cst)
		for prim in fixed:
			self.unregister(prim)
		# a scalar attribute counts for one unknown, a sequence for its length
		self.dim = sum(
				1 if isinstance(v, tuple) else len(v)
				for v in self.slvvars.values()
				)

	def register(self, obj):
		''' Register a constraint or a variable object 
		
			Raises TypeError if a callable `slvvars` yields a float or an int, since it 
			could not be modified in place.
		'''
		if hasattr(obj, 'fit'):
			self.constraints.add(obj.fit)
		# register object's variables
		if hasattr(obj, 'slvvars'):
			if callable(obj.slvvars):
				for var in obj.slvvars():
					if isinstance(var, (float, int)):
						raise TypeError("primitive types (float,int) are not allowed when 'slvvars' is a callable")
					self.register(var)
			else:
				for varname in obj.slvvars:
					var = getattr(obj, varname)
					if isinstance(var, (float,int)):
						# a scalar is immutable: remember where to read and write it
						self.slvvars[(id(obj), varname)] = (obj, varname)
					else:
						self.register(var)
		elif isinstance(obj, tuple):
			for p in obj:
				self.register(p)
		else:
			if id(obj) not in self.slvvars:
				self.slvvars[id(obj)] = obj
	
	def unregister(self, obj):
		''' Unregister all variables from a constraint or a variable object 
		
			Variables that are not registered are silently ignored.
		'''
		if hasattr(obj, 'slvvars'):
			if callable(obj.slvvars):
				for var in obj.slvvars():
					self.unregister(var)
			else:
				# only the non-callable form is an iterable of attribute names
				for varname in obj.slvvars:
					var = getattr(obj, varname)
					if isinstance(var, (float,int)):
						self.slvvars.pop((id(obj), varname), None)
					else:
						self.unregister(var)
		elif isinstance(obj, tuple):
			for p in obj:
				self.unregister(p)
		else:
			self.slvvars.pop(id(obj), None)
	
	def state(self):
		''' Flat float64 array of the current values of all the registered variables, in registration order '''
		x = np.empty(self.dim, dtype='f8')
		i = 0
		for v in self.slvvars.values():
			if isinstance(v, tuple):
				# scalar attribute designated by (obj, name)
				obj, name = v
				x[i] = getattr(obj, name)
				i += 1
			else:
				l = len(v)
				x[i:i+l] = v
				i += l
		return x
	
	def place(self, x):
		''' Write the flat vector `x` back into the registered variables, using the layout of `state()` '''
		i = 0
		for v in self.slvvars.values():
			if isinstance(v, tuple):
				# scalar attribute designated by (obj, name)
				obj, name = v
				setattr(obj, name, float(x[i]))
				i += 1
			else:
				l = len(v)
				for j in range(l):
					v[j] = x[i+j]
				i += l
	
	def fit(self):
		''' Concatenation of the residuals of all the constraints, as an `array.array('d')` '''
		residuals = array.array('d')
		for fit in self.constraints:
			residuals.extend(fit())
		return residuals
	
	def evaluate(self, x):
		''' Residuals obtained once the variables are set to `x` '''
		self.place(x)
		return self.fit()
	
	def solve(self, precision=1e-6, method='trf', maxiter=None, afterset=None):
		''' Run `scipy.optimize.least_squares` on the residuals and leave the result in the variables
		
			Parameters:
				precision:  used as `xtol`, with `precision**2` as `ftol` and `precision**3` as `gtol`, and also as the maximum accepted final cost
				method:     `method` argument of `least_squares`
				maxiter:    `max_nfev` argument of `least_squares`
				afterset:   optional callable, called with the candidate vector before each evaluation
			
			Returns the result object of `least_squares`.
			
			Raises SolveError when the final cost is not below `precision`.
		'''
		#nprint(self.slvvars)
		if afterset:
			# afterset is expected to return a false value, so that the residuals are returned
			evaluate = lambda x: afterset(x) or self.evaluate(x)
		else:
			evaluate = self.evaluate
		res = least_squares(
					evaluate, 
					self.state(), 
					xtol=precision, 
					gtol=precision**3,
					ftol=precision**2,
					method=method, 
					max_nfev=maxiter,
					)
		self.place(res.x)
		#print(res)
		if res.cost < precision:
			return res
		elif not res.success:
			raise SolveError(res.message)
		else:
			raise SolveError('no solution found')
def solve2(constraints, precision=1e-4, afterset=None, fixed=(), maxiter=0):
	''' Iterative solver averaging the corrections requested by each constraint
	
		Unlike `Problem`, the constraints must here provide 
		
			- `params()`       iterable of the parameter objects of the constraint
			- `corrections()`  iterable of the corrections wanted for them, in the same order
		
		The parameters are modified in place with `+=`.
		
		Parameters:
			constraints:  the constraints to satisfy
			precision:    iterations stop once the biggest contribution to a non fixed parameter is not above it
			afterset:     optional callable without argument, called after each iteration
			fixed:        container of the `id()` of the parameters that must not move
			maxiter:      maximum number of iterations, 0 for no limit
		
		Returns the biggest correction contribution of the last iteration.
		
		Raises SolveError if `maxiter` is exceeded, or if the corrections become 
		negligible while the precision is not reached.
	'''
	params = []
	corrections = []
	corrnorms = []
	indices = []
	knownparams = {}
	parts = []
	mincorr = precision * 1e-2
	
	# get parameters and prepare solver
	i = 0
	for const in constraints:
		c = 0
		for param in const.params():
			k = id(param)
			if k not in fixed:
				c += 1
			if k not in knownparams:
				knownparams[k] = i
				params.append(param)
				corrections.append(None)
				corrnorms.append(0)
				i += 1
			indices.append(knownparams[k])
		# number of free parameters of the constraint, its corrections are divided by it
		parts.append(c)
	
	# iterative resolution
	oldcorrs = [None] * len(corrections)
	maxdelta = inf
	it = 0
	while maxdelta > precision:
		maxcorr = 0
		maxdelta = 0
		
		# initialize corrections
		for i in range(len(corrections)):
			corrections[i] = type(params[i])()
			corrnorms[i] = 0.
		
		# compute constraints contributions to corrections
		i = 0
		for const,part in zip(constraints, parts):
			corr = const.corrections()
			for correction in corr:
				# NOTE(review): part is 0 when all the parameters of the constraint are fixed - confirm the division behaves as wanted in that case
				contrib = correction / part
				corrections[indices[i]] += contrib
				l = length(contrib)
				if l > corrnorms[indices[i]]:
					corrnorms[indices[i]] = l
				i += 1
		
		# apply changes
		for param,correction,corrnorm,oldcorr in zip(params, corrections, corrnorms, oldcorrs):
			if id(param) not in fixed:
				l = length(correction)
				if l > maxcorr:
					maxcorr = l
				if corrnorm > maxdelta:
					maxdelta = corrnorm
				# oscillation filter: average with the correction of the previous iteration
				if oldcorr:
					#print('  ', correction*0.5 + oldcorr*0.5)
					param += (correction*0.5 + oldcorr*0.5)
				else:
					param += correction * 0.8
		oldcorrs = corrections[:]
		
		if afterset:
			afterset()
		
		it += 1
		# check that the solver is solving
		if maxiter and it > maxiter:
			raise SolveError('failed to converge with the allowed iteration count: '+str(maxiter))
		if maxdelta > precision and maxcorr < mincorr:
			raise SolveError('resolution is blocked (try an other initial state)')
	
	#print('iterations:', it-1)
	return maxdelta